Skip to content

Commit

Permalink
Greatly simplify mlr3 workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Dec 26, 2023
1 parent 70e0c68 commit eb92e92
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 89 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ jobs:
clean = FALSE,
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package"),
function_exclusions = c(
"kernelshap\\.Learner",
"kernelshap\\.ranger",
"permshap\\.Learner",
"permshap\\.ranger",
"mlr3_pred_fun"
"permshap\\.ranger"
)
)
shell: Rscript {0}
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.4.1
Version: 0.4.2
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
person("David", "Watson", , "[email protected]", role = "aut"),
Expand Down
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(kernelshap,Learner)
S3method(kernelshap,default)
S3method(kernelshap,ranger)
S3method(permshap,Learner)
S3method(permshap,default)
S3method(permshap,ranger)
S3method(print,kernelshap)
Expand Down
11 changes: 11 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# kernelshap 0.4.2

## API

- {mlr3}: Non-probabilistic classification now works.
- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.

## Documentation

- The README has received an {mlr3} and {caret} example.

# kernelshap 0.4.1

## Performance improvements
Expand Down
33 changes: 0 additions & 33 deletions R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -355,36 +355,3 @@ kernelshap.ranger <- function(object, X, bg_X,
)
}

#' @describeIn kernelshap Kernel SHAP method for "mlr3" models, see Readme for an example.
#' @export
kernelshap.Learner <- function(object, X, bg_X,
pred_fun = NULL,
feature_names = colnames(X),
bg_w = NULL, exact = length(feature_names) <= 8L,
hybrid_degree = 1L + length(feature_names) %in% 4:16,
paired_sampling = TRUE,
m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
tol = 0.005, max_iter = 100L, parallel = FALSE,
parallel_args = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
kernelshap.default(
object = object,
X = X,
bg_X = bg_X,
pred_fun = pred_fun,
feature_names = feature_names,
bg_w = bg_w,
exact = exact,
hybrid_degree = hybrid_degree,
paired_sampling = paired_sampling,
m = m,
tol = tol,
max_iter = max_iter,
parallel = parallel,
parallel_args = parallel_args,
verbose = verbose,
...
)
}
23 changes: 0 additions & 23 deletions R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,3 @@ permshap.ranger <- function(object, X, bg_X,
)
}

#' @describeIn permshap Permutation SHAP method for "mlr3" models, see Readme for an example.
#' @export
permshap.Learner <- function(object, X, bg_X,
pred_fun = NULL,
feature_names = colnames(X),
bg_w = NULL, parallel = FALSE, parallel_args = NULL,
verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
permshap.default(
object = object,
X = X,
bg_X = bg_X,
pred_fun = pred_fun,
feature_names = feature_names,
bg_w = bg_w,
parallel = parallel,
parallel_args = parallel_args,
verbose = verbose,
...
)
}
23 changes: 0 additions & 23 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -380,26 +380,3 @@ prep_w <- function(w, bg_n) {
if (!is.double(w)) as.double(w) else w
}

#' mlr3 Helper
#'
#' Returns the prediction function of a mlr3 Learner.
#'
#' @noRd
#' @keywords internal
#'
#' @param object Learner object.
#' @param X Dataframe like object.
#'
#' @returns A function.
mlr3_pred_fun <- function(object, X) {
if ("classif" %in% object$task_type) {
# Check if probabilities are available
test_pred <- object$predict_newdata(utils::head(X))
if ("prob" %in% test_pred$predict_types) {
return(function(m, X) m$predict_newdata(X)$prob)
} else {
stop("Set lrn(..., predict_type = 'prob') to allow for probabilistic classification.")
}
}
function(m, X) m$predict_newdata(X)$response
}
56 changes: 54 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,11 @@ ps_class <- permshap(fit_class, X = iris[, -5], bg_X = iris)

![](man/figures/README-prob-dep.svg)

### Tidymodels
### Meta-learners

Meta-learning packages like {tidymodels}, {caret} or {mlr3} are straight-forward to use. The following example additionally shows that the `...` argument of `permshap()` and `kernelshap()` is passed to `predict()`.
Meta-learning packages like {tidymodels}, {caret} or {mlr3} are straightforward to use. The following examples additionally shows that the `...` arguments of `permshap()` and `kernelshap()` are passed to `predict()`.

#### Tidymodels

```r
library(kernelshap)
Expand Down Expand Up @@ -285,6 +287,56 @@ $.pred_setosa
[2,] 0.02628333 0.001315556 0.3683833 0.2706111
```

#### caret

```r
library(kernelshap)
library(caret)

fit <- train(
Sepal.Length ~ .,
data = iris,
method = "lm",
tuneGrid = data.frame(intercept = TRUE),
trControl = trainControl(method = "none")
)

ps <- permshap(fit, iris[-1], bg_X = iris)
```

#### mlr3

```r
library(kernelshap)
library(mlr3)
library(mlr3learners)

set.seed(1)

task_classif <- TaskClassif$new(id = "1", backend = iris, target = "Species")
learner_classif <- lrn("classif.rpart", predict_type = "prob")
learner_classif$train(task_classif)

predict(learner_classif, head(iris)) # setosa setosa # Classes
predict(learner_classif, head(iris), predict_type = "prob") # Probs per class

x <- learner_classif$selected_features()

# For *probabilistic* classification, pass predict_type = "prob" to mlr3's predict()
ps <- permshap(
learner_classif, X = iris, bg_X = iris, feature_names = x, predict_type = "prob"
)
ps
# $setosa
# Petal.Length Petal.Width
# [1,] 0.6666667 0
# [2,] 0.6666667 0

# Non-probabilistic classification uses auto-OHE internally
ps <- permshap(learner_classif, X = iris, bg_X = iris, feature_names = x)
ps
```

## References

[1] Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems 41, 2014.
Expand Down
2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Kernel SHAP",
Version = "0.4.1",
Version = "0.4.2",
Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017),
and Covert and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>.
Furthermore, for up to 14 features, exact permutation SHAP values can be calculated.
Expand Down

0 comments on commit eb92e92

Please sign in to comment.