Skip to content

Commit

Permalink
candidate fix for ModelOriented/DALEX#229
Browse files Browse the repository at this point in the history
  • Loading branch information
pbiecek committed May 21, 2020
1 parent 543bc4f commit f3b279c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
18 changes: 11 additions & 7 deletions R/feature_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#' while "difference" returns \code{drop_loss - drop_loss_full_model}
#' @param N number of observations that should be sampled for calculation of variable importance.
#' If \code{NULL} then variable importance will be calculated on the whole dataset (no sampling).
#' @param n_sample alias for \code{N}, kept for backwards compatibility. Number of observations that should be sampled for calculation of variable importance.
#' @param B integer, number of permutation rounds to perform on each variable. By default it's \code{10}.
#' @param variables vector of variables. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL}.
#' @param variable_groups list of vectors of variable names. This is for testing joint variable importance.
Expand Down Expand Up @@ -111,10 +112,11 @@ feature_importance.explainer <- function(x,
loss_function = DALEX::loss_root_mean_square,
...,
type = c("raw", "ratio", "difference"),
N = NULL,
n_sample = NULL,
B = 10,
variables = NULL,
variable_groups = NULL,
N = n_sample,
label = NULL) {
if (is.null(x$data)) stop("The feature_importance() function requires explainers created with specified 'data' parameter.")
if (is.null(x$y)) stop("The feature_importance() function requires explainers created with specified 'y' parameter.")
Expand All @@ -136,6 +138,7 @@ feature_importance.explainer <- function(x,
label = label,
type = type,
N = N,
n_sample = n_sample,
B = B,
variables = variables,
variable_groups = variable_groups,
Expand All @@ -153,15 +156,16 @@ feature_importance.default <- function(x,
...,
label = class(x)[1],
type = c("raw", "ratio", "difference"),
N = NULL,
n_sample = NULL,
B = 10,
variables = NULL,
N = n_sample,
variable_groups = NULL) {
# start: checks for arguments
if (is.null(N) & methods::hasArg("n_sample")) {
warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead")
N <- list(...)[["n_sample"]]
}
## if (is.null(N) & methods::hasArg("n_sample")) {
## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead")
## N <- list(...)[["n_sample"]]
## }

if (!is.null(variable_groups)) {
if (!inherits(variable_groups, "list")) stop("variable_groups should be of class list")
Expand Down Expand Up @@ -210,7 +214,7 @@ feature_importance.default <- function(x,
# loss on the full model or when outcomes are permuted
loss_full <- loss_function(observed, predict_function(x, sampled_data))
loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data))
# loss upon dropping single variables (or single groups)
# loss upon dropping a single variable (or a single group)
loss_features <- sapply(variables, function(variables_set) {
ndf <- sampled_data
ndf[, variables_set] <- ndf[sample(1:nrow(ndf)), variables_set]
Expand Down
12 changes: 8 additions & 4 deletions man/feature_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_variable_dropout.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ test_that("Output rf",{
})

test_that('deprecated n_sample', {
expect_warning(ingredients::feature_importance(explainer_glm, n_sample = 100))
expect_silent(ingredients::feature_importance(explainer_glm, n_sample = 100))
expect_silent(ingredients::feature_importance(explainer_glm, N = 100))
})

Expand Down

0 comments on commit f3b279c

Please sign in to comment.