Skip to content

Commit

Permalink
fix problematic argument survival in ranger models
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Aug 3, 2024
1 parent a1ed340 commit 74be45d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 23 deletions.
8 changes: 3 additions & 5 deletions R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,11 @@ kernelshap.ranger <- function(
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)


if (is.null(pred_fun)) {
pred_fun <- pred_ranger
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
}

kernelshap.default(
object = object,
X = X,
Expand All @@ -381,7 +380,6 @@ kernelshap.ranger <- function(
parallel = parallel,
parallel_args = parallel_args,
verbose = verbose,
survival = survival,
...
)
}
Expand Down
8 changes: 3 additions & 5 deletions R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,11 @@ permshap.ranger <- function(
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)


if (is.null(pred_fun)) {
pred_fun <- pred_ranger
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
}

permshap.default(
object = object,
X = X,
Expand All @@ -188,7 +187,6 @@ permshap.ranger <- function(
parallel = parallel,
parallel_args = parallel_args,
verbose = verbose,
survival = survival,
...
)
}
Expand Down
32 changes: 19 additions & 13 deletions R/pred_fun.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
#' Predict Function for Ranger
#'
#' Internal function that prepares the predictions of different types of ranger models,
#' including survival models.
#' Returns prediction function for different modes of ranger.
#'
#' @noRd
#' @keywords internal
#' @param model Fitted ranger model.
#' @param newdata Data to predict on.
#' @param treetype The value of `fit$treetype` in a fitted ranger model.
#' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time.
#' @param ... Additional arguments passed to ranger's predict function.
#'
#' @returns A vector or matrix with predictions.
pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) {
#' @returns A function with signature f(model, newdata, ...).
create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) {
survival <- match.arg(survival)

pred <- stats::predict(model, newdata, ...)
if (treetype != "Survival") {
pred_fun <- function(model, newdata, ...) {
stats::predict(model, newdata, ...)$predictions
}
return(pred_fun)
}

if (survival == "prob") {
survival <- "survival"
}

if (model$treetype == "Survival") {
out <- if (survival == "chf") pred$chf else pred$survival
pred_fun <- function(model, newdata, ...) {
pred <- stats::predict(model, newdata, ...)
out <- pred[[survival]]
colnames(out) <- paste0("t", pred$unique.death.times)
} else {
out <- pred$predictions
return(out)
}
return(out)
return(pred_fun)
}

28 changes: 28 additions & 0 deletions backlog/test_ranger.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
library(ranger)
library(survival)
library(kernelshap)

set.seed(1)

fit <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 20)
fit2 <- ranger(time ~ . - status, data = veteran, num.trees = 20)
fit3 <- ranger(time ~ . - status, data = veteran, quantreg = TRUE, num.trees = 20)
fit4 <- ranger(status ~ . - time, data = veteran, probability = TRUE, num.trees = 20)

xvars <- setdiff(colnames(veteran), c("time", "status"))

kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran)
permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran)

kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob")
permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob")

kernelshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran)
permshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran)

kernelshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles")
permshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles")

kernelshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran)
permshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran)

0 comments on commit 74be45d

Please sign in to comment.