This vignette explains how to use {shapviz} with {Tidymodels}.
XGBoost and LightGBM are shipped with super-fast TreeSHAP algorithms. Thus, doing a SHAP analysis is quite different from the normal case.
A model fitted with Tidymodels has a predict() method
that produces a data.frame with predictions. Therefore, working with
model-agnostic SHAP (permutation SHAP or Kernel SHAP) is as easy as it
can get. But is takes a little bit of time.
library(tidymodels)
library(kernelshap)
library(shapviz)
set.seed(10)
splits <- diamonds |> 
  transform(
    log_price = log(price),
    log_carat = log(carat)
) |> 
  initial_split()
df_train <- training(splits)
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut)
rf <- rand_forest(mode = "regression") |>
  set_engine("ranger")
rf_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(rf)
fit <- rf_wf |>
  fit(df_train)
  
# SHAP analysis
xvars <- c("log_carat", "color", "clarity", "cut")
X_explain <- df_train[1:1000, xvars]  # Use only feature columns
# 1.5 minutes on laptop
# Note: If you have more than p=8 features, use kernelshap() instead of permshap()
system.time(
  shap_values <- fit |> 
    permshap(X = X_explain) |> 
    shapviz()
)
# saveRDS(shap_values, file = "shap_values.rds")
# shap_values <- readRDS("shap_values.rds")
shap_values |> 
  sv_importance("bee")
shap_values |> 
  sv_dependence(xvars) 
When your Tidymodel is an XGBoost or LightGBM model, you will almost
always want to use their native TreeSHAP implementation. In this case,
you need to pass to shapviz() the fully prepared
explanation matrix X_pred and the underlying fit
engine.
We will show how to prepare the inputs for shapviz(),
namely
X_pred, the matrix passed to XGBoost’s
predict(),X, the dataframe used for visualizations
(to see original factor levels etc).Since XGBoost offers SHAP interactions, we additionally show how to integrate these into the analysis. of course, you don’t have to work with SHAP interactions, especially if your model has many predictors.
Remark: Don’t use 1:m transforms such as
One-Hot-Encodings. They are usually not necessary and make the workflow
more complicated. If you can’t avoid this, check the
collapse argument in shapviz().
library(tidymodels)
library(shapviz)
library(patchwork)
set.seed(10)
splits <- diamonds |> 
  transform(
    log_price = log(price),
    log_carat = log(carat)
) |> 
  initial_split()
df_train <- training(splits)
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut) |> 
  step_integer(all_ordered())
# Should be tuned in practice
xgb_model <- boost_tree(mode = "regression", learn_rate = 0.1, trees = 100) |>
  set_engine("xgboost")
xgb_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(xgb_model)
fit <- xgb_wf |>
  fit(df_train)
# SHAP Analysis
df_explain <- df_train[1:1000, ]
X_pred <- bake(  # Goes to xgboost:::predict.xgb.Booster()
  prep(dia_recipe), 
  has_role("predictor"),
  new_data = df_explain,
  composition = "matrix"
)
stopifnot(colnames(X_pred) %in% colnames(df_explain))
shap_values <- extract_fit_engine(fit) |> 
  shapviz(X_pred = X_pred, X = df_explain, interactions = TRUE)
# SHAP importance
shap_values |> 
  sv_importance(show_numbers = TRUE) +
  ggtitle("SHAP importance")
# Absolute average SHAP interactions (off-diagonals already multiplied by 2)
shap_values |> 
  sv_interaction(kind = "no")
#            log_carat     clarity       color         cut
# log_carat 0.87400688 0.067567245 0.032599394 0.024273852
# clarity   0.06756720 0.143393109 0.028236784 0.004910905
# color     0.03259941 0.028236796 0.095656042 0.004804729
# cut       0.02427382 0.004910904 0.004804732 0.031114735
# Usual dependence plot
xvars <- c("log_carat", "color", "clarity", "cut")
shap_values |> 
  sv_dependence(xvars) &
  plot_annotation("SHAP dependence plots")  # patchwork magic
# SHAP interactions for carat
shap_values |> 
  sv_dependence("log_carat", color_var = xvars, interactions = TRUE) &
  plot_annotation("SHAP interactions for carat") 
Regarding SHAP analysis and Tidymodels, LightGBM is slightly different from XGBoost:
library(tidymodels)
library(bonsai)
library(shapviz)
set.seed(10)
splits <- diamonds |> 
  transform(
    log_price = log(price),
    log_carat = log(carat)
) |> 
  initial_split()
df_train <- training(splits)
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut) |> 
  step_integer(color, clarity)  # we keep cut a factor (for illustration only)
# Should be tuned in practice
lgb_model <- boost_tree(mode = "regression", learn_rate = 0.1, trees = 100) |>
  set_engine("lightgbm")
lgb_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(lgb_model)
fit <- lgb_wf |>
  fit(df_train)
# SHAP analysis
df_explain <- df_train[1:1000, ]
X_pred <- bake(   # Goes to lightgbm:::predict.lgb.Booster()
  prep(dia_recipe), 
  has_role("predictor"),
  new_data = df_explain
) |> 
  bonsai:::prepare_df_lgbm()
  
head(X_pred, 2)
#       log_carat color clarity cut
# [1,]  0.3148107     5       5   3
# [2,] -0.5978370     2       3   4
stopifnot(colnames(X_pred) %in% colnames(df_explain))
shap_values <- extract_fit_engine(fit) |> 
  shapviz(X_pred = X_pred, X = df_explain)
shap_values |> 
  sv_importance(show_numbers = TRUE)
shap_values |> 
  sv_dependence(c("log_carat", "color", "clarity", "cut")) 
For probabilistic classification, the code is very similar to above regression examples.
shapviz() returns a list of “shapviz” objects (one per
class). Sometimes, you might want to analyze them together, or select an
individual class via $name_of_interesting_class or
[[.
Simply pass type = "prob" to
kernelshap::kernelshap() or
kernelshap::permshap():
library(tidymodels)
library(kernelshap)
library(shapviz)
library(patchwork)
set.seed(1)
iris_recipe <- iris |> 
  recipe(Species ~ .)
fit <- rand_forest(trees = 100) |>
  set_engine("ranger") |> 
  set_mode("classification")
  
iris_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(fit)
fit <- iris_wf |>
  fit(iris)
# SHAP analysis
X_explain <- iris[-5]  # Feature columns of <=2000 rows from the training data
system.time(  # 2s
  shap_values <- permshap(fit, X_explain, type = "prob") |> 
    shapviz()
)
sv_importance(shap_values)
shap_values |> 
  sv_dependence("Sepal.Length") +
  plot_layout(ncol = 1) +
  plot_annotation("SHAP dependence of one variable for all classes")
# Use $ to extract SHAP values for one class
shap_setosa <- shap_values$.pred_setosa
shap_setosa |> 
  sv_dependence(colnames(X_explain)) +
  plot_annotation("SHAP dependence of all variables for one class")For XGBoost and LightGBM, we again want to use its native TreeSHAP implementation.
We can slightly adapt the code from the regression example:
library(tidymodels)
library(shapviz)
library(patchwork)
set.seed(1)
iris_recipe <- iris |>
  recipe(Species ~ .)
xgb_model <- boost_tree(learn_rate = 0.1, trees = 100) |>
  set_mode("classification") |> 
  set_engine("xgboost", verbose = -1)
xgb_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(xgb_model)
fit <- xgb_wf |>
  fit(iris)
  
# SHAP analysis
df_explain <- iris  # Typically 1000 - 2000 rows from the training data
X_pred <- bake(  # goes to xgboost:::predict.xgb.Booster()
  prep(iris_recipe), 
  has_role("predictor"),
  new_data = df_explain,
  composition = "matrix"
)
stopifnot(colnames(X_pred) %in% colnames(df_explain))
shap_values <- extract_fit_engine(fit) |> 
  shapviz(X_pred = X_pred, X = df_explain) |> 
  setNames(levels(iris$Species))
shap_values |> 
  sv_importance()
shap_values |> 
  sv_dependence(v = "Sepal.Length", color_var = "Sepal.Width") +
  plot_layout(ncol = 1, guides = "collect")Let’s complete this vignette by running a binary LightGBM model.
library(tidymodels)
library(bonsai)
library(shapviz)
library(patchwork)
set.seed(1)
# Make factor with two levels
iris$sl_large <- factor(
  iris$Sepal.Length > median(iris$Sepal.Length), labels = c("no", "yes")
)
iris_recipe <- iris |>
  recipe(sl_large ~ Sepal.Width + Petal.Length + Petal.Width + Species) # |> 
  # step_integer(some ordinal factors)
  
lgb_model <- boost_tree(learn_rate = 0.1, trees = 100) |>
  set_mode("classification") |> 
  set_engine("lightgbm", verbose = -1)
lgb_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(lgb_model)
fit <- lgb_wf |>
  fit(iris)
  
# SHAP analysis
df_explain <- iris  # Typically 1000 - 2000 rows from the training data
X_pred <- bake(
  prep(iris_recipe), 
  has_role("predictor"),
  new_data = df_explain
) |> 
  bonsai:::prepare_df_lgbm()
stopifnot(colnames(X_pred) %in% colnames(df_explain))
shap_values <- extract_fit_engine(fit) |> 
  shapviz(X_pred = X_pred, X = df_explain)
shap_values |> 
  sv_importance()
shap_values |> 
  sv_dependence("Species")