Skip to content

Commit

Permalink
Merge pull request #145 from ModelOriented/update-readme
Browse files Browse the repository at this point in the history
Update time benchmarks
  • Loading branch information
mayer79 authored Sep 3, 2024
2 parents 29f0de1 + 22aee87 commit 5f65a24
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 35 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.7.0
Version: 0.7.1
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0009-0007-2540-9629")),
Expand Down
16 changes: 9 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ X <- diamonds[sample(nrow(diamonds), 1000), xvars]
# from X is used
bg_X <- diamonds[sample(nrow(diamonds), 200), ]

# 3) Crunch SHAP values for all 1000 rows of X (54 seconds)
# 3) Crunch SHAP values for all 1000 rows of X (22 seconds)
# Note: Since the number of features is small, we use permshap()
system.time(
ps <- permshap(fit, X, bg_X = bg_X)
Expand Down Expand Up @@ -137,8 +137,10 @@ plan(multisession, workers = 4) # Windows

fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)

system.time( # 9 seconds in parallel
ps <- permshap(fit, X, parallel = TRUE, parallel_args = list(.packages = "mgcv"))
system.time( # 4 seconds in parallel
ps <- permshap(
fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
)
)
ps

Expand All @@ -148,7 +150,7 @@ ps
# [2,] -0.51546 -0.1174766 0.11122775 0.030243973

# Because there are no interactions of order above 2, Kernel SHAP gives the same:
system.time( # 27 s non-parallel
system.time( # 13 s non-parallel
ks <- kernelshap(fit, X, bg_X = bg_X)
)
all.equal(ps$S, ks$S)
Expand Down Expand Up @@ -202,9 +204,9 @@ nn |>
)

pred_fun <- function(mod, X)
predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE)
predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE, workers = 4)

system.time( # 60 s
system.time( # 50 s
ps <- permshap(nn, X, bg_X = bg_X, pred_fun = pred_fun)
)

Expand Down Expand Up @@ -284,7 +286,7 @@ iris_wf <- workflow() |>
fit <- iris_wf |>
fit(iris)

system.time( # 4s
system.time( # 3s
ps <- permshap(fit, iris[-5], type = "prob")
)
ps
Expand Down
35 changes: 11 additions & 24 deletions backlog/compare_with_python.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ bg_X <- diamonds[seq(1, nrow(diamonds), 450), ]
# Subset of 1018 diamonds to explain
X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)]

# Exact KernelSHAP (5s)
# Exact KernelSHAP (2s)
system.time(
ks <- kernelshap(fit, X_small, bg_X = bg_X)
ks <- kernelshap(fit, X_small, bg_X = bg_X)
)
ks

Expand All @@ -25,9 +25,9 @@ ks
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
# [2,] -2.085838 0.04050415 0.1283010 0.03731644

# Pure sampling version takes a bit longer (12 seconds)
# Pure sampling version takes a bit longer (7 seconds)
system.time(
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
)
ks2

Expand All @@ -36,18 +36,6 @@ ks2
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
# [2,] -2.085838 0.04050415 0.1283010 0.03731644

# Using parallel backend
library("doFuture")

registerDoFuture()
plan(multisession, workers = 2) # Windows
# plan(multicore, workers = 2) # Linux, macOS, Solaris

# 3 seconds
system.time(
ks3 <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)
)
ks3

library(shapviz)

Expand All @@ -58,18 +46,17 @@ sv_dependence(sv, "carat")
# More features (but nonsensical model)
# Fit model
fit <- lm(
log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth,
log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth,
data = diamonds
)

# Subset of 1018 diamonds to explain
X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]

# Exact KernelSHAP on X_small, using X_small as background data
# (58/67(?) seconds for exact, 25/18 for hybrid deg 2, 16/9 for hybrid deg 1,
# 26/17 for pure sampling; second number with 2 parallel sessions on Windows)
# Exact KernelSHAP on X_small, using bg_X as background data
# (39s for exact, 15s for hybrid deg 2, 8s for hybrid deg 1, 16s for sampling)
system.time(
ks <- kernelshap(fit, X_small, bg_X = bg_X)
ks <- kernelshap(fit, X_small, bg_X = bg_X)
)
ks

Expand Down Expand Up @@ -98,7 +85,7 @@ X = diamonds[x].to_numpy()

# Fit model with interactions and dummy variables
fit = ols(
"np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth",
"np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth",
data=diamonds
).fit()

Expand All @@ -110,7 +97,7 @@ X_small = X[0:len(X):53]

# Calculate KernelSHAP values
ks = KernelExplainer(
model=lambda X: fit.predict(pd.DataFrame(X, columns=x)),
model=lambda X: fit.predict(pd.DataFrame(X, columns=x)),
data = bg_X
)
sv = ks.shap_values(X_small) # 11 minutes
Expand All @@ -127,4 +114,4 @@ sv[0:2]
# -1.72078182e-01, 1.33027467e-03, -6.44569296e-03],
# [-1.87670887e+00, 3.93291219e-02, 1.26654599e-01,
# 3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]])
# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]])
6 changes: 3 additions & 3 deletions packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ library(usethis)
use_description(
fields = list(
Title = "Kernel SHAP",
Version = "0.7.0",
Version = "0.7.1",
Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017),
and Covert and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>.
Furthermore, for up to 14 features, exact permutation SHAP values can be calculated.
The package plays well together with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'.
Visualizations can be done using the R package 'shapviz'.",
`Authors@R` =
`Authors@R` =
"c(person('Michael', family='Mayer', role=c('aut', 'cre'), email='[email protected]', comment=c(ORCID='0009-0007-2540-9629')),
person('David', family='Watson', role='aut', email='[email protected]', comment=c(ORCID='0000-0001-9632-2159')),
person('Przemyslaw', family='Biecek', email='[email protected]', role='ctb', comment=c(ORCID='0000-0001-8423-1823'))
Expand Down Expand Up @@ -98,7 +98,7 @@ install(upgrade = FALSE)
if (FALSE) {
check_win_devel()
check_rhub()

# Takes long
revdepcheck::revdep_check(num_workers = 4L, bioc = FALSE)

Expand Down

0 comments on commit 5f65a24

Please sign in to comment.