From 2f544c56788023ba30663ed881540d0d0a0a7529 Mon Sep 17 00:00:00 2001 From: olivroy <52606734+olivroy@users.noreply.github.com> Date: Tue, 5 Mar 2024 16:29:23 -0500 Subject: [PATCH] Upkeep (#33) * progress * fix typo * progress * Fix typo + add test + add testthat 3 * Fix ggplot2 deprecation warnings. linewidth + aes_string. * Add more informative names for tests * Fix dplyr deprecation warnings * Silence warnings as much as possible + update NEWS and DESCRIPTION * Update docs * Fix size bug. * Re-render docs * Update docs * Add `min_na()` and `max_na()` to avoid having many min / max warnings and coerce directly to NA inside `summarise()`. * add r cmd check + update readme * update docs * fix Note * Badge link correction --- .github/workflows/R-CMD-check.yaml | 29 +- DESCRIPTION | 14 +- NAMESPACE | 1 + NEWS.md | 17 +- R/measure_importance.R | 20 +- R/min_depth_distribution.R | 2 +- R/min_depth_interactions.R | 105 +- R/utils.R | 25 + README.md | 12 +- docs/404.html | 138 +- docs/articles/index.html | 127 +- docs/articles/randomForestExplainer.html | 668 +++++--- .../figure-html/unnamed-chunk-10-1.png | Bin 61040 -> 60907 bytes .../figure-html/unnamed-chunk-11-1.png | Bin 185846 -> 185613 bytes .../figure-html/unnamed-chunk-12-1.png | Bin 202917 -> 202667 bytes .../figure-html/unnamed-chunk-15-1.png | Bin 139678 -> 142474 bytes .../figure-html/unnamed-chunk-16-1.png | Bin 143668 -> 146722 bytes .../figure-html/unnamed-chunk-17-1.png | Bin 77237 -> 77626 bytes .../figure-html/unnamed-chunk-6-1.png | Bin 72371 -> 78002 bytes .../figure-html/unnamed-chunk-7-1.png | Bin 91514 -> 99915 bytes .../figure-html/unnamed-chunk-9-1.png | Bin 72201 -> 72457 bytes docs/authors.html | 172 +- docs/index.html | 111 +- docs/news/index.html | 167 +- docs/pkgdown.yml | 6 +- docs/reference/Rplot001.png | Bin 100322 -> 101393 bytes docs/reference/Rplot002.png | Bin 101332 -> 101696 bytes docs/reference/explain_forest.html | 250 +-- docs/reference/important_variables.html | 207 +-- 
docs/reference/index.html | 194 +-- docs/reference/measure_importance.html | 209 +-- docs/reference/min_depth_distribution.html | 1521 ++++++++--------- docs/reference/min_depth_interactions.html | 216 +-- docs/reference/plot_importance_ggpairs-1.png | Bin 71282 -> 66445 bytes docs/reference/plot_importance_ggpairs.html | 199 +-- docs/reference/plot_importance_rankings-1.png | Bin 66711 -> 66556 bytes docs/reference/plot_importance_rankings.html | 199 +-- .../plot_min_depth_distribution-1.png | Bin 53303 -> 54968 bytes .../plot_min_depth_distribution.html | 237 +-- .../plot_min_depth_interactions-1.png | Bin 119115 -> 119625 bytes .../plot_min_depth_interactions.html | 197 +-- .../reference/plot_multi_way_importance-1.png | Bin 56429 -> 56339 bytes docs/reference/plot_multi_way_importance.html | 237 +-- docs/reference/plot_predict_interaction-1.png | Bin 172619 -> 174818 bytes docs/reference/plot_predict_interaction-2.png | Bin 172858 -> 175126 bytes docs/reference/plot_predict_interaction.html | 240 +-- docs/sitemap.xml | 60 + tests/testthat/test-utils.R | 5 + tests/testthat/test_randomForest.R | 102 +- tests/testthat/test_ranger.R | 87 +- 50 files changed, 2389 insertions(+), 3385 deletions(-) create mode 100644 R/utils.R create mode 100644 docs/sitemap.xml create mode 100644 tests/testthat/test-utils.R diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 0528262..a3ac618 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -1,4 +1,4 @@ -# Workflow derived from https://github.com/r-lib/actions/tree/master/examples +# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help on: push: @@ -18,7 +18,7 @@ jobs: fail-fast: false matrix: config: - - {os: macOS-latest, r: 'release'} + - {os: macos-latest, r: 'release'} - {os: windows-latest, r: 'release'} - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} - {os: ubuntu-latest, r: 'release'} @@ -29,30 +29,21 @@ jobs: R_KEEP_PKG_SOURCE: yes steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: r-lib/actions/setup-pandoc@v1 + - uses: r-lib/actions/setup-pandoc@v2 - - uses: r-lib/actions/setup-r@v1 + - uses: r-lib/actions/setup-r@v2 with: r-version: ${{ matrix.config.r }} http-user-agent: ${{ matrix.config.http-user-agent }} use-public-rspm: true - - uses: r-lib/actions/setup-r-dependencies@v1 + - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: rcmdcheck + extra-packages: any::rcmdcheck + needs: check - - uses: r-lib/actions/check-r-package@v1 - - - name: Show testthat output - if: always() - run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true - shell: bash - - - name: Upload check results - if: failure() - uses: actions/upload-artifact@main + - uses: r-lib/actions/check-r-package@v2 with: - name: ${{ runner.os }}-r${{ matrix.config.r }}-results - path: check + upload-snapshots: true diff --git a/DESCRIPTION b/DESCRIPTION index 605e1c0..95c725b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: randomForestExplainer Title: Explaining and Visualizing Random Forests in Terms of Variable Importance -Version: 0.10.1 +Version: 0.10.2 Authors@R: c( person("Aleksandra", "Paluszynska", email = "ola.paluszynska@gmail.com", role = c("aut")), person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut","ths")), @@ -10,22 +10,24 @@ Description: A set of tools to help explain which variables are most important i Depends: R (>= 3.0) License: GPL Encoding: UTF-8 -LazyData: true Imports: data.table (>= 1.10.4), dplyr (>= 0.7.1), DT (>= 0.2), 
GGally (>= 1.3.0), - ggplot2 (>= 2.2.1), + ggplot2 (>= 3.4.0), ggrepel (>= 0.6.5), randomForest (>= 4.6.12), ranger(>= 0.9.0), - reshape2 (>= 1.4.2), - rmarkdown (>= 1.5) + rlang, + rmarkdown (>= 1.5), + tidyr Suggests: knitr, MASS (>= 7.3.47), testthat VignetteBuilder: knitr -RoxygenNote: 7.1.0 +RoxygenNote: 7.3.0 URL: https://github.com/ModelOriented/randomForestExplainer, https://modeloriented.github.io/randomForestExplainer/ +Config/testthat/edition: 3 +Config/Needs/website: ModelOriented/DrWhyTemplate diff --git a/NAMESPACE b/NAMESPACE index 680841b..bb42932 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -26,6 +26,7 @@ import(ggplot2) import(ggrepel) importFrom(data.table,frankv) importFrom(data.table,rbindlist) +importFrom(rlang,.data) importFrom(stats,as.formula) importFrom(stats,predict) importFrom(stats,terms) diff --git a/NEWS.md b/NEWS.md index 7b9992f..e99f98e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,22 @@ +# randomForestExplainer 0.10.2 + +* Remove dependency on reshape2 in favour of tidyr (@olivroy, #33) + +* Silence deprecation warnings from ggplot2 and dplyr (@olivroy, #29) + +* Use testthat 3rd edition. (@olivroy, #33) + +# randomForestExplainer 0.10.1 + +* Small tweaks to `explain_forest()`. + # randomForestExplainer 0.10.0 + ## New features * Added support for ranger forests. * Added support for unsupervised randomForest. * Added tests for most functions. ## Bug fixes -* Fixed bug for explain_forest not finding templates. -* Added more intuitive error message for explain_forest when local importance is absent. +* Fixed bug for `explain_forest()` not finding templates. +* Added more intuitive error message for `explain_forest()` when local `importance` is absent. 
diff --git a/R/measure_importance.R b/R/measure_importance.R index b3a16e9..5773916 100644 --- a/R/measure_importance.R +++ b/R/measure_importance.R @@ -10,8 +10,7 @@ measure_min_depth <- function(min_depth_frame, mean_sample){ # randomForest measure_no_of_nodes <- function(forest_table){ `split var` <- NULL - frame <- dplyr::group_by(forest_table, `split var`) %>% dplyr::summarize(n()) - colnames(frame) <- c("variable", "no_of_nodes") + frame <- dplyr::group_by(forest_table, variable = `split var`) %>% dplyr::summarize(no_of_nodes = dplyr::n()) frame <- as.data.frame(frame[!is.na(frame$variable),]) frame$variable <- as.character(frame$variable) return(frame) @@ -21,8 +20,7 @@ measure_no_of_nodes <- function(forest_table){ # randomForest measure_no_of_nodes_ranger <- function(forest_table){ splitvarName <- NULL - frame <- dplyr::group_by(forest_table, splitvarName) %>% dplyr::summarize(n()) - colnames(frame) <- c("variable", "no_of_nodes") + frame <- dplyr::group_by(forest_table, variable = splitvarName) %>% dplyr::summarize(no_of_nodes = n()) frame <- as.data.frame(frame[!is.na(frame$variable),]) frame$variable <- as.character(frame$variable) return(frame) @@ -75,8 +73,7 @@ measure_vimp_ranger <- function(forest){ measure_no_of_trees <- function(min_depth_frame){ variable <- NULL frame <- dplyr::group_by(min_depth_frame, variable) %>% - dplyr::summarize(count = n()) %>% as.data.frame() - colnames(frame)[2] <- "no_of_trees" + dplyr::summarize(no_of_trees = n()) %>% as.data.frame() frame$variable <- as.character(frame$variable) return(frame) } @@ -85,8 +82,7 @@ measure_no_of_trees <- function(min_depth_frame){ measure_times_a_root <- function(min_depth_frame){ variable <- NULL frame <- min_depth_frame[min_depth_frame$minimal_depth == 0, ] %>% - dplyr::group_by(variable) %>% dplyr::summarize(count = n()) %>% as.data.frame() - colnames(frame)[2] <- "times_a_root" + dplyr::group_by(variable) %>% dplyr::summarize(times_a_root = n()) %>% as.data.frame() frame$variable <- 
as.character(frame$variable) return(frame) } @@ -329,13 +325,13 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de if(size_measure == "p_value"){ data$p_value <- cut(data$p_value, breaks = c(-Inf, 0.01, 0.05, 0.1, Inf), labels = c("<0.01", "[0.01, 0.05)", "[0.05, 0.1)", ">=0.1"), right = FALSE) - plot <- ggplot(data, aes_string(x = x_measure, y = y_measure)) + - geom_point(aes_string(color = size_measure), size = 3) + + plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]])) + + geom_point(aes(color = .data[[size_measure]]), size = 3) + geom_point(data = data_for_labels, color = "black", stroke = 2, aes(alpha = "top"), size = 3, shape = 21) + geom_label_repel(data = data_for_labels, aes(label = variable), show.legend = FALSE) + theme_bw() + scale_alpha_discrete(name = "variable", range = c(1, 1)) } else { - plot <- ggplot(data, aes_string(x = x_measure, y = y_measure, size = size_measure)) + + plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]], size = .data[[size_measure]])) + geom_point(aes(colour = "black")) + geom_point(data = data_for_labels, aes(colour = "blue")) + geom_label_repel(data = data_for_labels, aes(label = variable, size = NULL), show.legend = FALSE) + scale_colour_manual(name = "variable", values = c("black", "blue"), labels = c("non-top", "top")) + @@ -345,7 +341,7 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de } } } else { - plot <- ggplot(data, aes_string(x = x_measure, y = y_measure)) + + plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]])) + geom_point(aes(colour = "black")) + geom_point(data = data_for_labels, aes(colour = "blue")) + geom_label_repel(data = data_for_labels, aes(label = variable, size = NULL), show.legend = FALSE) + scale_colour_manual(name = "variable", values = c("black", "blue"), labels = c("non-top", "top")) + diff --git a/R/min_depth_distribution.R b/R/min_depth_distribution.R index 
fd54f1c..5beb787 100644 --- a/R/min_depth_distribution.R +++ b/R/min_depth_distribution.R @@ -175,7 +175,7 @@ plot_min_depth_distribution <- function(min_depth_frame, k = 10, min_no_of_trees plot <- ggplot(data, aes(x = variable, y = count)) + geom_col(position = position_stack(reverse = TRUE), aes(fill = as.factor(minimal_depth))) + coord_flip() + scale_x_discrete(limits = rev(levels(data$variable))) + - geom_errorbar(aes(ymin = mean_minimal_depth_label, ymax = mean_minimal_depth_label), size = 1.5) + + geom_errorbar(aes(ymin = mean_minimal_depth_label, ymax = mean_minimal_depth_label), linewidth = 1.5) + xlab("Variable") + ylab("Number of trees") + guides(fill = guide_legend(title = "Minimal depth")) + theme_bw() + geom_label(data = data_for_labels, aes(y = mean_minimal_depth_label, label = mean_minimal_depth)) diff --git a/R/min_depth_interactions.R b/R/min_depth_interactions.R index 6627ac0..8f047c2 100644 --- a/R/min_depth_interactions.R +++ b/R/min_depth_interactions.R @@ -63,17 +63,22 @@ min_depth_interactions_values <- function(forest, vars){ mutate_if(is.factor, as.character) %>% calculate_tree_depth() %>% cbind(., tree = i, number = 1:nrow(.))) %>% data.table::rbindlist() %>% as.data.frame() - interactions_frame[vars] <- as.numeric(NA) + interactions_frame[vars] <- NA_real_ interactions_frame <- data.table::as.data.table(interactions_frame)[, conditional_depth(as.data.frame(.SD), vars), by = tree] %>% as.data.frame() mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>% - dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame() - mean_tree_depth[mean_tree_depth == -Inf] <- NA + dplyr::summarise( + dplyr::across({{ vars }}, .fns = max_na) + ) %>% + as.data.frame() mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE) + min_depth_interactions_frame <- interactions_frame %>% dplyr::group_by(tree, `split var`) %>% - dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% 
as.data.frame() - min_depth_interactions_frame[min_depth_interactions_frame == Inf] <- NA + dplyr::summarise( + dplyr::across({{ vars }}, .fns = min_na) + ) %>% + as.data.frame() min_depth_interactions_frame <- min_depth_interactions_frame[!is.na(min_depth_interactions_frame$`split var`), ] colnames(min_depth_interactions_frame)[2] <- "variable" min_depth_interactions_frame[, -c(1:2)] <- min_depth_interactions_frame[, -c(1:2)] - 1 @@ -88,19 +93,22 @@ min_depth_interactions_values_ranger <- function(forest, vars){ lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>% calculate_tree_depth_ranger() %>% cbind(., tree = i, number = 1:nrow(.))) %>% data.table::rbindlist() %>% as.data.frame() - interactions_frame[vars] <- as.numeric(NA) + interactions_frame[vars] <- NA_real_ interactions_frame <- data.table::as.data.table(interactions_frame)[, conditional_depth_ranger(as.data.frame(.SD), vars), by = tree] %>% as.data.frame() mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>% - dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame() - mean_tree_depth[mean_tree_depth == -Inf] <- NA + dplyr::summarise( + dplyr::across({{ vars }}, .fns = max_na) + ) %>% + as.data.frame() mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE) min_depth_interactions_frame <- - interactions_frame %>% dplyr::group_by(tree, splitvarName) %>% - dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame() - min_depth_interactions_frame[min_depth_interactions_frame == Inf] <- NA - min_depth_interactions_frame <- min_depth_interactions_frame[!is.na(min_depth_interactions_frame$splitvarName), ] - colnames(min_depth_interactions_frame)[2] <- "variable" + interactions_frame %>% dplyr::group_by(tree, variable = splitvarName) %>% + dplyr::summarise( + dplyr::across(.cols = {{ vars }}, .fns = min_na) + ) %>% + as.data.frame() + min_depth_interactions_frame <- 
min_depth_interactions_frame[!is.na(min_depth_interactions_frame$variable), ] min_depth_interactions_frame[, -c(1:2)] <- min_depth_interactions_frame[, -c(1:2)] - 1 return(list(min_depth_interactions_frame, mean_tree_depth)) } @@ -137,11 +145,17 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl min_depth_interactions_frame <- min_depth_interactions_frame[[1]] interactions_frame <- min_depth_interactions_frame %>% dplyr::group_by(variable) %>% - dplyr::summarize_at(vars, funs(mean(., na.rm = TRUE))) %>% as.data.frame() + dplyr::summarise( + dplyr::across({{ vars }}, function(x) mean(x, na.rm = TRUE)) + ) %>% + as.data.frame() interactions_frame[is.na(as.matrix(interactions_frame))] <- NA occurrences <- min_depth_interactions_frame %>% dplyr::group_by(variable) %>% - dplyr::summarize_at(vars, funs(sum(!is.na(.)))) %>% as.data.frame() + dplyr::summarise( + dplyr::across({{ vars }}, function(x) sum(!is.na(x))) + ) %>% + as.data.frame() if(mean_sample == "all_trees"){ non_occurrences <- occurrences non_occurrences[, -1] <- forest$ntree - occurrences[, -1] @@ -157,19 +171,26 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] + as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$ntree - minimum_non_occurrences) } - interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable") - colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth") - occurrences <- reshape2::melt(occurrences, id.vars = "variable") - colnames(occurrences)[2:3] <- c("root_variable", "occurrences") + interactions_frame <- tidyr::pivot_longer( + interactions_frame, + cols = -"variable", + names_to = "root_variable", + values_to = "mean_min_depth" + ) + occurrences <- tidyr::pivot_longer( + occurrences, + cols = -"variable", + names_to = "root_variable", + values_to = "occurrences" + ) 
interactions_frame <- merge(interactions_frame, occurrences) interactions_frame$interaction <- paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":") forest_table <- lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>% mutate_if(is.factor, as.character) %>% calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist() - min_depth_frame <- dplyr::group_by(forest_table, tree, `split var`) %>% - dplyr::summarize(min(depth)) - colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth") + min_depth_frame <- dplyr::group_by(forest_table, tree, variable = `split var`) %>% + dplyr::summarize(minimal_depth = min(depth)) min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),]) importance_frame <- get_min_depth_means(min_depth_frame, min_depth_count(min_depth_frame), uncond_mean_sample) colnames(importance_frame)[2] <- "uncond_mean_min_depth" @@ -187,11 +208,17 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea min_depth_interactions_frame <- min_depth_interactions_frame[[1]] interactions_frame <- min_depth_interactions_frame %>% dplyr::group_by(variable) %>% - dplyr::summarize_at(vars, funs(mean(., na.rm = TRUE))) %>% as.data.frame() + dplyr::summarise( + dplyr::across({{ vars }}, function(x) mean(x, na.rm = TRUE)) + ) %>% + as.data.frame() interactions_frame[is.na(as.matrix(interactions_frame))] <- NA occurrences <- min_depth_interactions_frame %>% dplyr::group_by(variable) %>% - dplyr::summarize_at(vars, funs(sum(!is.na(.)))) %>% as.data.frame() + dplyr::summarise( + dplyr::across({{ vars }}, function(x) sum(!is.na(x), na.rm = TRUE)) + ) %>% + as.data.frame() if(mean_sample == "all_trees"){ non_occurrences <- occurrences non_occurrences[, -1] <- forest$num.trees - occurrences[, -1] @@ -207,18 +234,15 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea interactions_frame[, -1] <- (interactions_frame[, -1] * 
occurrences[, -1] + as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$num.trees - minimum_non_occurrences) } - interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable") - colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth") - occurrences <- reshape2::melt(occurrences, id.vars = "variable") - colnames(occurrences)[2:3] <- c("root_variable", "occurrences") + interactions_frame <- tidyr::pivot_longer(interactions_frame, cols = -"variable", names_to = "root_variable", values_to = "mean_min_depth") + occurrences <- tidyr::pivot_longer(occurrences, cols = -"variable", names_to = "root_variable", values_to = "occurrences") interactions_frame <- merge(interactions_frame, occurrences) interactions_frame$interaction <- paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":") forest_table <- lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>% calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist() - min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>% - dplyr::summarize(min(depth)) - colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth") + min_depth_frame <- dplyr::group_by(forest_table, tree, variable = splitvarName) %>% + dplyr::summarize(minimal_depth = min(depth)) min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),]) importance_frame <- get_min_depth_means(min_depth_frame, min_depth_count(min_depth_frame), uncond_mean_sample) colnames(importance_frame)[2] <- "uncond_mean_min_depth" @@ -251,14 +275,14 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30, interactions_frame[ order(interactions_frame$occurrences, decreasing = TRUE), "interaction"]) minimum <- min(interactions_frame$mean_min_depth, na.rm = TRUE) - if(is.null(k)) k <- length(levels(interactions_frame$interaction)) + if(is.null(k)) k <- nlevels(interactions_frame$interaction) plot 
<- ggplot(interactions_frame[interactions_frame$interaction %in% levels(interactions_frame$interaction)[1:k] & !is.na(interactions_frame$mean_min_depth), ], aes(x = interaction, y = mean_min_depth, fill = occurrences)) + geom_bar(stat = "identity") + geom_pointrange(aes(ymin = pmin(mean_min_depth, uncond_mean_min_depth), y = uncond_mean_min_depth, ymax = pmax(mean_min_depth, uncond_mean_min_depth), shape = "unconditional"), fatten = 2, size = 1) + - geom_hline(aes(yintercept = minimum, linetype = "minimum"), color = "red", size = 1.5) + + geom_hline(aes(yintercept = minimum, linetype = "minimum"), color = "red", linewidth = 1.5) + scale_linetype_manual(name = NULL, values = 1) + theme_bw() + scale_shape_manual(name = NULL, values = 19) + theme(axis.text.x = element_text(angle = 45, hjust = 1)) @@ -324,7 +348,7 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia } if(forest$type == "regression"){ newdata$prediction <- predict(forest, newdata, type = "response") - plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) + + plot <- ggplot(newdata, aes(x = .data[[variable1]], y = .data[[variable2]], fill = prediction)) + geom_raster() + theme_bw() + scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)), low = "blue", high = "red") @@ -335,9 +359,9 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia } else { newdata[, paste0("probability_", forest$classes)] <- predict(forest, newdata, type = "prob") } - newdata <- reshape2::melt(newdata, id.vars = id_vars) + newdata <- tidyr::pivot_longer(newdata, cols = !dplyr::all_of(id_vars), names_to = "variable") newdata$prediction <- newdata$value - plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) + + plot <- ggplot(newdata, aes(x = .data[[variable1]], y = .data[[variable2]], fill = prediction)) + geom_raster() + theme_bw() + facet_wrap(~ 
variable) + scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)), low = "blue", high = "red") @@ -352,6 +376,7 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia #' @importFrom stats predict #' @importFrom stats terms #' @importFrom stats as.formula +#' @importFrom rlang .data #' @export plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, grid = 100, main = paste0("Prediction of the forest for different values of ", @@ -372,7 +397,7 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, } if(forest$treetype == "Regression"){ newdata$prediction <- predict(forest, newdata, type = "response")$predictions - plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) + + plot <- ggplot(newdata, aes(x = .data[[variable1]], y = .data[[variable2]], fill = prediction)) + geom_raster() + theme_bw() + scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)), low = "blue", high = "red") @@ -384,9 +409,9 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, } else { newdata[, paste0("probability_", colnames(pred))] <- pred } - newdata <- reshape2::melt(newdata, id.vars = id_vars) + newdata <- tidyr::pivot_longer(newdata, cols = !dplyr::all_of(id_vars), names_to = "variable") newdata$prediction <- newdata$value - plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) + + plot <- ggplot(newdata, aes(x = .data[[variable1]], y = .data[[variable2]], fill = prediction)) + geom_raster() + theme_bw() + facet_wrap(~ variable) + scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)), low = "blue", high = "red") @@ -403,7 +428,7 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, time <- new_time } 
newdata$prediction <- pred$survival[, pred$unique.death.times == time, drop = TRUE] - plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) + + plot <- ggplot(newdata, aes(x = .data[[variable1]], y = .data[[variable2]], fill = prediction)) + geom_raster() + theme_bw() + scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)), low = "blue", high = "red") diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..14099c3 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,25 @@ +# Helpers to avoid warnings in computations +# Are all values NA? +all_na <- function(x) { + if (!anyNA(x)) { + return(FALSE) + } + all(is.na(x)) +} +# Min but returns NA if only has NA +min_na <- function(x) { + if (all_na(x)) { + return(NA) + } + min(x, na.rm = TRUE) +} +# max but returns NA if only has NA +max_na <- function(x) { + if (all_na(x)) { + return(NA) + } + max(x, na.rm = TRUE) +} + +utils::globalVariables(c("prediction", "variable")) + diff --git a/README.md b/README.md index fcbd04a..12ae211 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,15 @@ # randomForestExplainer -[![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/randomForestExplainer)](https://cran.r-project.org/package=randomForestExplainer) -[![R-CMD-check](https://github.com/ModelOriented/randomForestExplainer/workflows/R-CMD-check/badge.svg)](https://github.com/ModelOriented/randomForestExplainer/actions) -[![codecov](https://codecov.io/gh/ModelOriented/randomForestExplainer/branch/master/graph/badge.svg)](https://codecov.io/gh/ModelOriented/randomForestExplainer) + + + +[![CRAN status](https://www.r-pkg.org/badges/version/randomForestExplainer)](https://cran.r-project.org/package=randomForestExplainer) +[![R-CMD-check](https://github.com/ModelOriented/randomForestExplainer/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/randomForestExplainer/actions/workflows/R-CMD-check.yaml) 
+[![codecov](https://codecov.io/gh/ModelOriented/randomForestExplainer/branch/master/graph/badge.svg)](https://app.codecov.io/gh/ModelOriented/randomForestExplainer?branch=master) [![DOI](https://zenodo.org/badge/97007621.svg)](https://zenodo.org/badge/latestdoi/97007621) + -A set of tools to understand what is happening inside a Random Forest. A detailed discussion of the package and importance measures it implements can be found here: [Master thesis on randomForestExplainer](https://cdn.staticaly.com/gh/geneticsMiNIng/BlackBoxOpener/master/randomForestExplainer_Master_thesis.pdf). +A set of tools to understand what is happening inside a Random Forest. A detailed discussion of the package and importance measures it implements can be found here: [Master thesis on randomForestExplainer](https://github.com/geneticsMiNIng/BlackBoxOpener/blob/master/randomForestExplainer_Master_thesis.pdf). ## Installation diff --git a/docs/404.html b/docs/404.html index d20490d..554e992 100644 --- a/docs/404.html +++ b/docs/404.html @@ -1,88 +1,40 @@ - - -
- + + + + -To further explore variable importance measures we pass our forest to measure_importance
function and get the following data frame (we save and load it from memory to save time):
To further explore variable importance measures we pass our forest to
+measure_importance
function and get the following data
+frame (we save and load it from memory to save time):
-# importance_frame <- measure_importance(forest)
-# save(importance_frame, file = "importance_frame.rda")
-load("importance_frame.rda")
-importance_frame
## variable mean_min_depth no_of_nodes mse_increase node_purity_increase
-## 1 age 3.308000 8936 3.7695582 1144.6959
-## 2 black 3.512000 7855 1.6677224 762.5438
-## 3 chas 6.591152 761 0.4931158 193.7997
-## 4 crim 2.386000 9434 8.8550476 2556.8119
-## 5 dis 2.600000 9210 7.5408462 2461.5665
-## 6 indus 3.166000 4182 7.5565917 3083.5072
-## 7 lstat 1.288000 11443 62.8221475 12401.4000
-## 8 nox 2.578000 6187 10.3991589 2625.6542
-## 9 ptratio 2.868000 4572 6.5315832 2269.6530
-## 10 rad 5.115968 2631 1.2258054 324.9312
-## 11 rm 1.346000 11394 34.8226290 12848.2579
-## 12 tax 3.556000 4402 3.5985825 1090.7962
-## 13 zn 6.087424 1529 0.6720070 300.3424
-## no_of_trees times_a_root p_value
-## 1 500 2 8.381103e-225
-## 2 500 1 5.822067e-81
-## 3 411 0 1.000000e+00
-## 4 500 23 6.498487e-313
-## 5 500 1 1.188152e-271
-## 6 500 96 1.000000e+00
-## 7 500 135 0.000000e+00
-## 8 500 36 9.833401e-01
-## 9 500 46 1.000000e+00
-## 10 499 4 1.000000e+00
-## 11 500 139 0.000000e+00
-## 12 500 11 1.000000e+00
-## 13 482 6 1.000000e+00
-It contains 13 rows, each corresponding to a predictor, and 8 columns of which one stores the variable names and the rest store the variable importance measures of a variable \(X_j\):
+# importance_frame <- measure_importance(forest)
+# save(importance_frame, file = "importance_frame.rda")
+load("importance_frame.rda")
+importance_frame
## variable mean_min_depth no_of_nodes mse_increase node_purity_increase
+## 1 age 3.308000 8936 3.7695582 1144.6959
+## 2 black 3.512000 7855 1.6677224 762.5438
+## 3 chas 6.591152 761 0.4931158 193.7997
+## 4 crim 2.386000 9434 8.8550476 2556.8119
+## 5 dis 2.600000 9210 7.5408462 2461.5665
+## 6 indus 3.166000 4182 7.5565917 3083.5072
+## 7 lstat 1.288000 11443 62.8221475 12401.4000
+## 8 nox 2.578000 6187 10.3991589 2625.6542
+## 9 ptratio 2.868000 4572 6.5315832 2269.6530
+## 10 rad 5.115968 2631 1.2258054 324.9312
+## 11 rm 1.346000 11394 34.8226290 12848.2579
+## 12 tax 3.556000 4402 3.5985825 1090.7962
+## 13 zn 6.087424 1529 0.6720070 300.3424
+## no_of_trees times_a_root p_value
+## 1 500 2 8.381103e-225
+## 2 500 1 5.822067e-81
+## 3 411 0 1.000000e+00
+## 4 500 23 6.498487e-313
+## 5 500 1 1.188152e-271
+## 6 500 96 1.000000e+00
+## 7 500 135 0.000000e+00
+## 8 500 36 9.833401e-01
+## 9 500 46 1.000000e+00
+## 10 499 4 1.000000e+00
+## 11 500 139 0.000000e+00
+## 12 500 11 1.000000e+00
+## 13 482 6 1.000000e+00
+It contains 13 rows, each corresponding to a predictor, and 8 columns +of which one stores the variable names and the rest store the variable +importance measures of a variable \(X_j\):
accuracy_decrease
(classification) – mean decrease of prediction accuracy after \(X_j\) is permuted,
gini_decrease
(classification) – mean decrease in the Gini index of node impurity (i.e. increase of node purity) by splits on \(X_j\),
mse_increase
(regression) – mean increase of mean squared error after \(X_j\) is permuted,
node_purity_increase
(regression) – mean node purity increase by splits on \(X_j\), as measured by the decrease in sum of squares,
mean_minimal_depth
– mean minimal depth calculated in one of three ways specified by the parameter mean_sample
,
no_of_trees
– total number of trees in which a split on \(X_j\) occurs,
no_of_nodes
– total number of nodes that use \(X_j\) for splitting (it is usually equal to no_of_trees
if trees are shallow),
times_a_root
– total number of trees in which \(X_j\) is used for splitting the root node (i.e., the whole sample is divided into two based on the value of \(X_j\)),
p_value
– \(p\)-value for the one-sided binomial test using the following distribution: \[Bin(\texttt{no_of_nodes},\ \mathbf{P}(\text{node splits on } X_j)),\] where we calculate the probability of split on \(X_j\) as if \(X_j\) was uniformly drawn from the \(r\) candidate variables \[\mathbf{P}(\text{node splits on } X_j) = \mathbf{P}(X_j \text{ is a candidate})\cdot\mathbf{P}(X_j \text{ is selected}) = \frac{r}{p}\cdot \frac{1}{r} = \frac{1}{p}.\] This test tells us whether the observed number of successes (number of nodes in which \(X_j\) was used for splitting) exceeds the theoretical number of successes if they were random (i.e. following the binomial distribution given above).
accuracy_decrease
(classification) – mean decrease
+of prediction accuracy after \(X_j\) is
+permuted,
gini_decrease
(classification) – mean decrease in
+the Gini index of node impurity (i.e. increase of node purity) by splits
+on \(X_j\),
mse_increase
(regression) – mean increase of mean
+squared error after \(X_j\) is
+permuted,
node_purity_increase
(regression) – mean node purity
+increase by splits on \(X_j\), as
+measured by the decrease in sum of squares,
mean_minimal_depth
– mean minimal depth calculated
+in one of three ways specified by the parameter
+mean_sample
,
no_of_trees
– total number of trees in which a split
+on \(X_j\) occurs,
no_of_nodes
– total number of nodes that use \(X_j\) for splitting (it is usually equal to
+no_of_trees
if trees are shallow),
times_a_root
– total number of trees in which \(X_j\) is used for splitting the root node
+(i.e., the whole sample is divided into two based on the value of \(X_j\)),
p_value
– \(p\)-value for the one-sided binomial test
+using the following distribution: \[Bin(\texttt{no_of_nodes},\ \mathbf{P}(\text{node
+splits on } X_j)),\] where we calculate the probability of split
+on \(X_j\) as if \(X_j\) was uniformly drawn from the \(r\) candidate variables \[\mathbf{P}(\text{node splits on } X_j) =
+\mathbf{P}(X_j \text{ is a candidate})\cdot\mathbf{P}(X_j \text{ is
+selected}) = \frac{r}{p}\cdot \frac{1}{r} = \frac{1}{p}.\] This
+test tells us whether the observed number of successes (number of nodes
+in which \(X_j\) was used for
+splitting) exceeds the theoretical number of successes if they were
+random (i.e. following the binomial distribution given above).
Measures (a)-(d) are calculated by the randomForest
package so need only to be extracted from our forest
object if option localImp = TRUE
was used for growing the forest (we assume this is the case). Note that measures (a) and (c) are based on the decrease in predictive accuracy of the forest after perturbation of the variable, (b) and (d) are based on changes in node purity after splits on the variable and (e)-(i) are based on the structure of the forest.
The function measure_importance
allows you to specify the method of calculating mean minimal depth (mean_sample
parameter, default "top_trees"
) and the measures to be calculated as a character vector, a subset of names of measures given above (measures
parameter, default to NULL
leads to calculating all measures).
Below we present the result of plot_multi_way_importance
for the default values of x_measure
and y_measure
, which specify measures to use on \(x\) and \(y\)-axis, and the size of points reflects the number of nodes split on the variable. For problems with many variables we can restrict the plot to only those used for splitting in at least min_no_of_trees
trees. By default 10 top variables in the plot are highlighted in blue and labeled (no_of_labels
) – these are selected using the function important_variables
, i.e. using the sum of rankings based on importance measures used in the plot (more variables may be labeled if ties occur).
Measures (a)-(d) are calculated by the randomForest
+package so need only to be extracted from our forest
object
+if option localImp = TRUE
was used for growing the forest
+(we assume this is the case). Note that measures (a) and (c) are based
+on the decrease in predictive accuracy of the forest after perturbation
+of the variable, (b) and (d) are based on changes in node purity after
+splits on the variable and (e)-(i) are based on the structure of the
+forest.
The function measure_importance
allows you to specify
+the method of calculating mean minimal depth (mean_sample
+parameter, default "top_trees"
) and the measures to be
+calculated as a character vector, a subset of names of measures given
+above (measures
parameter, default to NULL
+leads to calculating all measures).
Below we present the result of plot_multi_way_importance
+for the default values of x_measure
and
+y_measure
, which specify measures to use on \(x\) and \(y\)-axis, and the size of points reflects
+the number of nodes split on the variable. For problems with many
+variables we can restrict the plot to only those used for splitting in
+at least min_no_of_trees
trees. By default 10 top variables
+in the plot are highlighted in blue and labeled
+(no_of_labels
) – these are selected using the function
+important_variables
, i.e. using the sum of rankings based
+on importance measures used in the plot (more variables may be labeled
+if ties occur).
-# plot_multi_way_importance(forest, size_measure = "no_of_nodes") # gives the same result as below but takes longer
-plot_multi_way_importance(importance_frame, size_measure = "no_of_nodes")
# plot_multi_way_importance(forest, size_measure = "no_of_nodes") # gives the same result as below but takes longer
+plot_multi_way_importance(importance_frame, size_measure = "no_of_nodes")
Observe the marked negative relation between times_a_root
and mean_min_depth
. Also, the superiority of lstat
and rm
is clear in all three dimensions plotted (though it is not clear which of the two is better). Further, we present the multi-way importance plot for a different set of importance measures: increase of mean squared error after permutation (\(x\)-axis), increase in the node purity index (\(y\)-axis) and levels of significance (color of points). We also set no_of_labels
to five so that only five top variables will be highlighted (as ties occur, six are eventually labeled).
Observe the marked negative relation between
+times_a_root
and mean_min_depth
. Also, the
+superiority of lstat
and rm
is clear in all
+three dimensions plotted (though it is not clear which of the two is
+better). Further, we present the multi-way importance plot for a
+different set of importance measures: increase of mean squared error
+after permutation (\(x\)-axis),
+increase in the node purity index (\(y\)-axis) and levels of significance (color
+of points). We also set no_of_labels
to five so that only
+five top variables will be highlighted (as ties occur, six are
+eventually labeled).
-plot_multi_way_importance(importance_frame, x_measure = "mse_increase", y_measure = "node_purity_increase", size_measure = "p_value", no_of_labels = 5)
plot_multi_way_importance(importance_frame, x_measure = "mse_increase", y_measure = "node_purity_increase", size_measure = "p_value", no_of_labels = 5)
As in the previous plot, the two measures used as coordinates seem correlated, but in this case this is somewhat more surprising as one is connected to the structure of the forest and the other to its prediction, whereas in the previous plot both measures reflected the structure. Also, in this plot we see that although lstat
and rm
are similar in terms of node purity increase and \(p\)-value, the former is markedly better if we look at the increase in MSE. Interestingly, nox
and indus
are quite good when it comes to the two measures reflected on the axes, but are not significant according to our \(p\)-value, which is a derivative of the number of nodes that use a variable for splitting.
As in the previous plot, the two measures used as coordinates seem
+correlated, but in this case this is somewhat more surprising as one is
+connected to the structure of the forest and the other to its
+prediction, whereas in the previous plot both measures reflected the
+structure. Also, in this plot we see that although lstat
+and rm
are similar in terms of node purity increase and
+\(p\)-value, the former is markedly
+better if we look at the increase in MSE. Interestingly,
+nox
and indus
are quite good when it comes to
+the two measures reflected on the axes, but are not significant
+according to our \(p\)-value, which is
+a derivative of the number of nodes that use a variable for
+splitting.
Generally, the multi-way importance plot offers a wide variety of possibilities so it can be hard to select the most informative one. One idea of overcoming this obstacle is to first explore relations between different importance measures to then select three that least agree with each other and use them in the multi-way importance plot to select top variables. The first is easily done by plotting selected importance measures pairwise against each other using plot_importance_ggpairs
as below. One could of course include all seven measures in the plot but by default \(p\)-value and the number of trees are excluded as both carry similar information as the number of nodes.
Generally, the multi-way importance plot offers a wide variety of
+possibilities so it can be hard to select the most informative one. One
+idea of overcoming this obstacle is to first explore relations between
+different importance measures to then select three that least agree with
+each other and use them in the multi-way importance plot to select top
+variables. The first is easily done by plotting selected importance
+measures pairwise against each other using
+plot_importance_ggpairs
as below. One could of course
+include all seven measures in the plot but by default \(p\)-value and the number of trees are
+excluded as both carry similar information as the number of nodes.
-# plot_importance_ggpairs(forest) # gives the same result as below but takes longer
-plot_importance_ggpairs(importance_frame)
# plot_importance_ggpairs(forest) # gives the same result as below but takes longer
+plot_importance_ggpairs(importance_frame)
We can see that all depicted measures are highly correlated (of course the correlation of any measure with mean minimal depth is negative as the latter is lowest for best variables), but some less than others. Moreover, regardless of which measures we compare, there always seem to be two points that stand out and these most likely correspond to lstat
and rm
(to know for sure we could just examine the importance_frame
).
We can see that all depicted measures are highly correlated (of
+course the correlation of any measure with mean minimal depth is
+negative as the latter is lowest for best variables), but some less than
+others. Moreover, regardless of which measures we compare, there always
+seem to be two points that stand out and these most likely correspond to
+lstat
and rm
(to know for sure we could just
+examine the importance_frame
).
In addition to scatter plots and correlation coefficients, the ggpairs plot also depicts density estimate for each importance measure – all of which are in this case very skewed. An attempt to eliminate this feature by plotting rankings instead of raw measures is implemented in the function plot_importance_rankings
that also includes the fitted LOESS curve in each plot.
In addition to scatter plots and correlation coefficients, the
+ggpairs plot also depicts density estimate for each importance measure –
+all of which are in this case very skewed. An attempt to eliminate this
+feature by plotting rankings instead of raw measures is implemented in
+the function plot_importance_rankings
that also includes
+the fitted LOESS curve in each plot.
-# plot_importance_rankings(forest) # gives the same result as below but takes longer
-plot_importance_rankings(importance_frame)
# plot_importance_rankings(forest) # gives the same result as below but takes longer
+plot_importance_rankings(importance_frame)
The above density estimates show that skewness was eliminated for all of our importance measures (this is not always the case, e.g., when ties in rankings are frequent, and this is likely for discrete importance measures such as times_a_root
, then the distribution of the ranking will also be skewed).
When comparing the rankings in the above plot we can see that two pairs of measures almost exactly agree in their rankings of variables: mean_min_depth
vs. mse_increase
and mse_increase
vs. node_purity_increase
. In applications where there are many variables, the LOESS curve may be the main takeaway from this plot (if points fill in the whole plotting area and this is likely if the distributions of measures are close to uniform).
The above density estimates show that skewness was eliminated for all
+of our importance measures (this is not always the case, e.g., when ties
+in rankings are frequent, and this is likely for discrete importance
+measures such as times_a_root
, then the distribution of the
+ranking will also be skewed).
When comparing the rankings in the above plot we can see that two
+pairs of measures almost exactly agree in their rankings of variables:
+mean_min_depth
vs. mse_increase
and
+mse_increase
vs. node_purity_increase
. In
+applications where there are many variables, the LOESS curve may be the
+main takeaway from this plot (if points fill in the whole plotting area
+and this is likely if the distributions of measures are close to
+uniform).
After selecting a set of most important variables we can investigate interactions with respect to them, i.e. splits appearing in maximal subtrees with respect to one of the variables selected. To extract the names of 5 most important variables according to both the mean minimal depth and number of trees in which a variable appeared, we pass our importance_frame
to the function important_variables
as follows:
After selecting a set of most important variables we can investigate
+interactions with respect to them, i.e. splits appearing in maximal
+subtrees with respect to one of the variables selected. To extract the
+names of 5 most important variables according to both the mean minimal
+depth and number of trees in which a variable appeared, we pass our
+importance_frame
to the function
+important_variables
as follows:
-# (vars <- important_variables(forest, k = 5, measures = c("mean_min_depth", "no_of_trees"))) # gives the same result as below but takes longer
-(vars <- important_variables(importance_frame, k = 5, measures = c("mean_min_depth", "no_of_trees")))
## [1] "lstat" "rm" "crim" "nox" "dis"
-We pass the result together with our forest to the min_depth_interactions
function to obtain a data frame containing information on mean conditional minimal depth of variables with respect to each element of vars
(missing values are filled analogously as for unconditional minimal depth, in one of three ways specified by mean_sample
). If we would not specify the vars
argument then the vector of conditioning variables would be by default obtained using important_variables(measure_importance(forest))
.
# (vars <- important_variables(forest, k = 5, measures = c("mean_min_depth", "no_of_trees"))) # gives the same result as below but takes longer
+(vars <- important_variables(importance_frame, k = 5, measures = c("mean_min_depth", "no_of_trees")))
## [1] "lstat" "rm" "crim" "nox" "dis"
+We pass the result together with our forest to the
+min_depth_interactions
function to obtain a data frame
+containing information on mean conditional minimal depth of variables
+with respect to each element of vars
(missing values are
+filled analogously as for unconditional minimal depth, in one of three
+ways specified by mean_sample
). If we would not specify the
+vars
argument then the vector of conditioning variables
+would be by default obtained using
+important_variables(measure_importance(forest))
.
-# interactions_frame <- min_depth_interactions(forest, vars)
-# save(interactions_frame, file = "interactions_frame.rda")
-load("interactions_frame.rda")
-head(interactions_frame[order(interactions_frame$occurrences, decreasing = TRUE), ])
## variable root_variable mean_min_depth occurrences interaction
-## 53 rm lstat 1.179381 485 lstat:rm
-## 18 crim lstat 1.934738 478 lstat:crim
-## 3 age lstat 2.388948 475 lstat:age
-## 23 dis lstat 1.786887 475 lstat:dis
-## 33 lstat lstat 1.584338 474 lstat:lstat
-## 8 black lstat 2.870078 468 lstat:black
-## uncond_mean_min_depth
-## 53 1.346
-## 18 2.386
-## 3 3.308
-## 23 2.600
-## 33 1.288
-## 8 3.512
-Then, we pass our interactions_frame
to the plotting function plot_min_depth_interactions
and obtain the following:
# interactions_frame <- min_depth_interactions(forest, vars)
+# save(interactions_frame, file = "interactions_frame.rda")
+load("interactions_frame.rda")
+head(interactions_frame[order(interactions_frame$occurrences, decreasing = TRUE), ])
## variable root_variable mean_min_depth occurrences interaction
+## 53 rm lstat 1.179381 485 lstat:rm
+## 18 crim lstat 1.934738 478 lstat:crim
+## 3 age lstat 2.388948 475 lstat:age
+## 23 dis lstat 1.786887 475 lstat:dis
+## 33 lstat lstat 1.584338 474 lstat:lstat
+## 8 black lstat 2.870078 468 lstat:black
+## uncond_mean_min_depth
+## 53 1.346
+## 18 2.386
+## 3 3.308
+## 23 2.600
+## 33 1.288
+## 8 3.512
+Then, we pass our interactions_frame
to the plotting
+function plot_min_depth_interactions
and obtain the
+following:
-# plot_min_depth_interactions(forest) # calculates the interactions_frame for default settings so may give different results than the function below depending on our settings and takes more time
-plot_min_depth_interactions(interactions_frame)
# plot_min_depth_interactions(forest) # calculates the interactions_frame for default settings so may give different results than the function below depending on our settings and takes more time
+plot_min_depth_interactions(interactions_frame)
Note that the interactions are ordered by decreasing number of occurrences – the most frequent one, lstat:rm
, is also the one with minimal mean conditional minimal depth. Remarkably, the unconditional mean minimal depth of rm
in the forest is almost equal to its mean minimal depth across maximal subtrees with lstat
as the root variable.
Generally, the plot contains much information and can be interpreted in many ways but always bear in mind the method used for calculating the conditional (mean_sample
parameter) and unconditional (uncond_mean_sample
parameter) mean minimal depth. Using the default "top_trees"
penalizes interactions that occur less frequently than the most frequent one. Of course, one can switch between "all_trees"
, "top_trees"
and "relevant_trees"
for calculating the mean of both the conditional and unconditional minimal depth but each of them has its drawbacks and we favour using "top_trees"
(the default). However, as plot_min_depth_interactions
plots interactions by decreasing frequency the major drawback of calculating the mean only for relevant variables vanishes as interactions appearing for example only once but with conditional depth 0 will not be included in the plot anyway. Thus, we repeat the computation of means using "relevant_trees"
and get the following result:
Note that the interactions are ordered by decreasing number of
+occurrences – the most frequent one, lstat:rm
, is also the
+one with minimal mean conditional minimal depth. Remarkably, the
+unconditional mean minimal depth of rm
in the forest is
+almost equal to its mean minimal depth across maximal subtrees with
+lstat
as the root variable.
Generally, the plot contains much information and can be interpreted
+in many ways but always bear in mind the method used for calculating the
+conditional (mean_sample
parameter) and unconditional
+(uncond_mean_sample
parameter) mean minimal depth. Using
+the default "top_trees"
penalizes interactions that occur
+less frequently than the most frequent one. Of course, one can switch
+between "all_trees"
, "top_trees"
and
+"relevant_trees"
for calculating the mean of both the
+conditional and unconditional minimal depth but each of them has its
+drawbacks and we favour using "top_trees"
(the default).
+However, as plot_min_depth_interactions
plots interactions
+by decreasing frequency the major drawback of calculating the mean only
+for relevant variables vanishes as interactions appearing for example
+only once but with conditional depth 0 will not be included in the plot
+anyway. Thus, we repeat the computation of means using
+"relevant_trees"
and get the following result:
-# interactions_frame <- min_depth_interactions(forest, vars, mean_sample = "relevant_trees", uncond_mean_sample = "relevant_trees")
-# save(interactions_frame, file = "interactions_frame_relevant.rda")
-load("interactions_frame_relevant.rda")
-plot_min_depth_interactions(interactions_frame)
# interactions_frame <- min_depth_interactions(forest, vars, mean_sample = "relevant_trees", uncond_mean_sample = "relevant_trees")
+# save(interactions_frame, file = "interactions_frame_relevant.rda")
+load("interactions_frame_relevant.rda")
+plot_min_depth_interactions(interactions_frame)
Comparing this plot with the previous one we see that removing penalization of missing values lowers the mean conditional minimal depth of all interactions except the most frequent one. Now, in addition to the frequent ones, some of the less frequent like rm:tax
stand out.
Comparing this plot with the previous one we see that removing
+penalization of missing values lowers the mean conditional minimal depth
+of all interactions except the most frequent one. Now, in addition to
+the frequent ones, some of the less frequent like rm:tax
+stand out.
To further investigate the most frequent interaction lstat:rm
we use the function plot_predict_interaction
to plot the prediction of our forest on a grid of values for the components of each interaction. The function requires the forest, training data, variable to use on \(x\) and \(y\)-axis, respectively. In addition, one can also decrease the number of points in both dimensions of the grid from the default of 100 in case of insufficient memory using the parameter grid
.
To further investigate the most frequent interaction
+lstat:rm
we use the function
+plot_predict_interaction
to plot the prediction of our
+forest on a grid of values for the components of each interaction. The
+function requires the forest, training data, variable to use on \(x\) and \(y\)-axis, respectively. In addition, one
+can also decrease the number of points in both dimensions of the grid
+from the default of 100 in case of insufficient memory using the
+parameter grid
.
-plot_predict_interaction(forest, Boston, "rm", "lstat")
plot_predict_interaction(forest, Boston, "rm", "lstat")
In the above plot we can clearly see the effect of interaction: the predicted median price is highest when lstat
is low and rm
is high and low when the reverse is true. To further investigate the effect of interactions we could plot other frequent ones on a grid.
In the above plot we can clearly see the effect of interaction: the
+predicted median price is highest when lstat
is low and
+rm
is high and low when the reverse is true. To further
+investigate the effect of interactions we could plot other frequent ones
+on a grid.
The explain_forest()
function is the flagship function of the randomForestExplainer
package, as it takes your random forest and produces a html report that summarizes all basic results obtained for the forest with the new package. Below, we show how to run this function with default settings (we only supply the forest, training data, set interactions = TRUE
contrary to the default to show full functionality and decrease the grid for prediction plots in our most computationally-intense examples) for our data set.
The explain_forest()
function is the flagship function
+of the randomForestExplainer
package, as it takes your
+random forest and produces a html report that summarizes all basic
+results obtained for the forest with the new package. Below, we show how
+to run this function with default settings (we only supply the forest,
+training data, set interactions = TRUE
contrary to the
+default to show full functionality and decrease the grid for prediction
+plots in our most computationally-intense examples) for our data
+set.
-explain_forest(forest, interactions = TRUE, data = Boston)
To see the resulting HTML document click here: Boston forest summary
-For additional examples see: initial vignette.
+explain_forest(forest, interactions = TRUE, data = Boston)
To see the resulting HTML document click here: Boston forest summary
+For additional examples see: initial vignette.
Developed by Aleksandra Paluszynska, Przemyslaw Biecek, Yue Jiang.
+Developed by Aleksandra Paluszynska, Przemyslaw Biecek, Yue Jiang.
0RY$bv>gx%Q`!@gEs>hO}~)!Bel%$J2|27=%x4l-O6b{^6Awi-?wUA
z^U*sodyO+E6O--QqmsO<79D+yns_w9O>BBxmUqfK=8gh+B(?bKmyI0yzU=~P8LXH>
zqN_|a^eKC*`o^74g?IPvQz{zf?#T{Q6Qd?#zD>$8A|JG0!HVyALVr?OBEO-I9IX
z+=5gztz;B6S1VUGbua3+o s>s~Nz
z1ZNs@OGu;_-hyWed{h*WAb|-0dYk|3{Im