diff --git a/DESCRIPTION b/DESCRIPTION index a61da91..16763e2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: kernelshap Title: Kernel SHAP -Version: 0.7.0 +Version: 0.7.1 Authors@R: c( person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0009-0007-2540-9629")), diff --git a/README.md b/README.md index 2552c5d..3f9c1b9 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ X <- diamonds[sample(nrow(diamonds), 1000), xvars] # from X is used bg_X <- diamonds[sample(nrow(diamonds), 200), ] -# 3) Crunch SHAP values for all 1000 rows of X (54 seconds) +# 3) Crunch SHAP values for all 1000 rows of X (22 seconds) # Note: Since the number of features is small, we use permshap() system.time( ps <- permshap(fit, X, bg_X = bg_X) @@ -137,8 +137,10 @@ plan(multisession, workers = 4) # Windows fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds) -system.time( # 9 seconds in parallel - ps <- permshap(fit, X, parallel = TRUE, parallel_args = list(.packages = "mgcv")) +system.time( # 4 seconds in parallel + ps <- permshap( + fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv") + ) ) ps @@ -148,7 +150,7 @@ ps # [2,] -0.51546 -0.1174766 0.11122775 0.030243973 # Because there are no interactions of order above 2, Kernel SHAP gives the same: -system.time( # 27 s non-parallel +system.time( # 13 s non-parallel ks <- kernelshap(fit, X, bg_X = bg_X) ) all.equal(ps$S, ks$S) @@ -202,9 +204,9 @@ nn |> ) pred_fun <- function(mod, X) - predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE) + predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE, workers = 4) -system.time( # 60 s +system.time( # 50 s ps <- permshap(nn, X, bg_X = bg_X, pred_fun = pred_fun) ) @@ -284,7 +286,7 @@ iris_wf <- workflow() |> fit <- iris_wf |> fit(iris) -system.time( # 4s +system.time( # 3s ps <- permshap(fit, iris[-5], type = "prob") ) ps diff --git a/backlog/compare_with_python.R b/backlog/compare_with_python.R index ee449b6..f4ae07c 100644 --- a/backlog/compare_with_python.R +++ b/backlog/compare_with_python.R @@ -14,9 +14,9 @@ bg_X <- diamonds[seq(1, nrow(diamonds), 450), ] # Subset of 1018 diamonds to explain X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)] -# Exact KernelSHAP (5s) +# Exact KernelSHAP (2s) system.time( - ks <- kernelshap(fit, X_small, bg_X = bg_X) + ks <- kernelshap(fit, X_small, bg_X = bg_X) ) ks @@ -25,9 +25,9 @@ ks # [1,] -2.050074 -0.28048747 0.1281222 0.01587382 # [2,] -2.085838 0.04050415 0.1283010 0.03731644 -# Pure sampling version takes a bit longer (12 seconds) +# Pure sampling version takes a bit longer (7 seconds) system.time( - ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0) + ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0) ) ks2 @@ -36,18 +36,6 @@ ks2 # [1,] -2.050074 -0.28048747 0.1281222 0.01587382 # [2,] -2.085838 0.04050415 0.1283010 0.03731644 -# Using parallel backend -library("doFuture") - -registerDoFuture() -plan(multisession, workers = 2) # Windows -# plan(multicore, workers = 2) # Linux, macOS, Solaris - -# 3 seconds -system.time( - ks3 <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE) -) -ks3 library(shapviz) @@ -58,18 +46,17 @@ sv_dependence(sv, "carat") # More features (but non-sensical model) # Fit model fit <- lm( - log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth, + log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth, data = diamonds ) # Subset of 1018 diamonds to explain X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")] -# Exact KernelSHAP on X_small, using X_small as background data -# (58/67(?) seconds for exact, 25/18 for hybrid deg 2, 16/9 for hybrid deg 1, -# 26/17 for pure sampling; second number with 2 parallel sessions on Windows) +# Exact KernelSHAP on X_small, using X_small as background data +# (39s for exact, 15s for hybrid deg 2, 8s for hybrid deg 1, 16s for sampling) system.time( - ks <- kernelshap(fit, X_small, bg_X = bg_X) + ks <- kernelshap(fit, X_small, bg_X = bg_X) ) ks @@ -98,7 +85,7 @@ X = diamonds[x].to_numpy() # Fit model with interactions and dummy variables fit = ols( - "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth", + "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth", data=diamonds ).fit() @@ -110,7 +97,7 @@ X_small = X[0:len(X):53] # Calculate KernelSHAP values ks = KernelExplainer( - model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), + model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), data = bg_X ) sv = ks.shap_values(X_small) # 11 minutes @@ -127,4 +114,4 @@ sv[0:2] # -1.72078182e-01, 1.33027467e-03, -6.44569296e-03], # [-1.87670887e+00, 3.93291219e-02, 1.26654599e-01, # 3.85695742e-02, -4.87177593e-04, -4.20263565e-04, -# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]]) \ No newline at end of file +# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]]) diff --git a/packaging.R b/packaging.R index 8dbdcae..df37cc1 100644 --- a/packaging.R +++ b/packaging.R @@ -15,13 +15,13 @@ library(usethis) use_description( fields = list( Title = "Kernel SHAP", - Version = "0.7.0", + Version = "0.7.1", Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and Covert and Lee (2021) . Furthermore, for up to 14 features, exact permutation SHAP values can be calculated. The package plays well together with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'. Visualizations can be done using the R package 'shapviz'.", - `Authors@R` = + `Authors@R` = "c(person('Michael', family='Mayer', role=c('aut', 'cre'), email='mayermichael79@gmail.com', comment=c(ORCID='0009-0007-2540-9629')), person('David', family='Watson', role='aut', email='david.s.watson11@gmail.com', comment=c(ORCID='0000-0001-9632-2159')), person('Przemyslaw', family='Biecek', email='przemyslaw.biecek@gmail.com', role='ctb', comment=c(ORCID='0000-0001-8423-1823')) @@ -98,7 +98,7 @@ install(upgrade = FALSE) if (FALSE) { check_win_devel() check_rhub() - + # Takes long revdepcheck::revdep_check(num_workers = 4L, bioc = FALSE)