Let’s walk through an example of using the DALEX package with a
classification model that predicts survival on the Titanic. Here
we use the titanic_imputed dataset available in the
DALEX package. Note that this data was copied from the
stablelearner package and changed for practicality.
library("DALEX")
head(titanic_imputed)
#>   gender age class    embarked  fare sibsp parch survived
#> 1   male  42   3rd Southampton  7.11     0     0        0
#> 2   male  13   3rd Southampton 20.05     0     2        0
#> 3   male  16   3rd Southampton 20.05     1     1        0
#> 4 female  39   3rd Southampton 20.05     1     1        1
#> 5 female  16   3rd Southampton  7.13     0     0        1
#> 6   male  25   3rd Southampton  7.13     0     0        1
Now it’s time to create a model. Let’s use a random forest.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
                             fare + sibsp + parch,
                           data = titanic_imputed, probability = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, probability = TRUE)
#>
#> Type: Probability estimation
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 10
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error (Brier s.): 0.1422968
The third step (optional but useful) is to create a
DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "Random Forest")
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest
#> -> data : 2207 rows 7 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.14.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.01164526 , mean = 0.3215481 , max = 0.9899436
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.7923093 , mean = 0.0006086512 , max = 0.8905081
#> A new explainer has been created!
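Use feature_importance() on the explainer to present the
importance of particular features. The call below uses the default
type = "raw", which reports the loss after permuting each variable.
Passing type = "difference" instead subtracts the loss of the full
model, so that all dropouts start at 0.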
library("ingredients")
fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
#>       variable mean_dropout_loss         label
#> 1 _full_model_         0.3408062 Random Forest
#> 2        parch         0.3520488 Random Forest
#> 3        sibsp         0.3520933 Random Forest
#> 4     embarked         0.3527842 Random Forest
#> 5          age         0.3760269 Random Forest
#> 6         fare         0.3848921 Random Forest
plot(fi_rf)
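To make the difference between raw and normalized dropouts concrete, here is a minimal base-R sketch of permutation importance. It is not the ingredients implementation: the toy logistic model on mtcars, the RMSE loss, and the helper names rmse() and permuted_loss() are all assumptions made for illustration.

```r
# Base-R sketch of permutation importance (illustrative, not the package code).
set.seed(123)

# Toy classifier on a built-in dataset (assumption: stands in for the forest)
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# Loss function assumed for this sketch: root mean square error
rmse <- function(y, p) sqrt(mean((y - p)^2))

# Loss of the unmodified model
loss_full <- rmse(mtcars$am, predict(fit, newdata = mtcars, type = "response"))

# Average loss after shuffling one variable B times
permuted_loss <- function(v, B = 20) {
  mean(replicate(B, {
    shuffled <- mtcars
    shuffled[[v]] <- sample(shuffled[[v]])
    rmse(mtcars$am, predict(fit, newdata = shuffled, type = "response"))
  }))
}

# "raw" dropouts: the losses themselves
raw <- c("_full_model_" = loss_full, sapply(c("wt", "hp"), permuted_loss))

# "difference" dropouts: shifted so that the full model sits at 0
difference <- raw - loss_full
```

Printing raw and difference side by side shows the same ordering of variables, only the baseline moves.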
As we can see, the most important feature is gender. The next
three important features are class, age and
fare. Let’s see the link between the model response and these
features.
Such a univariate relation can be calculated with
partial_dependence().
Children aged 5 and younger have a much higher survival probability.
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
#> Top profiles :
#>   _vname_       _label_       _x_    _yhat_ _ids_
#> 1    fare Random Forest 0.0000000 0.3630884     0
#> 2     age Random Forest 0.1666667 0.5347603     0
#> 3     age Random Forest 2.0000000 0.5536098     0
#> 4     age Random Forest 4.0000000 0.5595259     0
#> 5    fare Random Forest 6.1793080 0.3100674     0
#> 6     age Random Forest 7.0000000 0.5159751     0
plot(pp_age)
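The idea behind a partial dependence profile can be sketched in a few lines of base R: fix one variable at a grid value for every row, predict, and average. This is only an illustration under assumptions; the toy logistic model on mtcars and the helper partial_dependence_sketch() are hypothetical and do not reproduce the package internals.

```r
# Base-R sketch of a partial dependence profile (illustrative only).

# Toy classifier on a built-in dataset (assumption for this sketch)
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# For each grid value: overwrite the variable in all rows, predict, average
partial_dependence_sketch <- function(model, data, variable, grid) {
  sapply(grid, function(value) {
    modified <- data
    modified[[variable]] <- value
    mean(predict(model, newdata = modified, type = "response"))
  })
}

wt_grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 10)
pd_wt <- data.frame(
  wt   = wt_grid,
  yhat = partial_dependence_sketch(fit, mtcars, "wt", wt_grid)
)
```

Calling plot(pd_wt, type = "l") draws the averaged response as a function of wt.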
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
Let’s explain the model prediction for an 8-year-old male passenger from 1st class who embarked in Southampton.
First, Ceteris Paribus profiles for numerical variables:
new_passanger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passanger)
plot(sp_rf) +
show_observations(sp_rf)
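A Ceteris Paribus profile follows the same recipe as partial dependence, but for a single observation and without averaging: one variable is varied over a grid while all others stay fixed. The base-R sketch below illustrates this under assumptions; the toy logistic model on mtcars and the chosen observation are hypothetical stand-ins for the forest and the new passenger.

```r
# Base-R sketch of a Ceteris Paribus profile (illustrative only).

# Toy classifier on a built-in dataset (assumption for this sketch)
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# The single observation to explain
observation <- mtcars["Mazda RX4", ]

# Grid of values for the variable that is allowed to change
hp_grid <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 10)

# Repeat the observation once per grid value and overwrite only hp
profile <- observation[rep(1, length(hp_grid)), ]
profile$hp <- hp_grid

# Model response along the grid, all other variables held constant
profile$yhat <- predict(fit, newdata = profile, type = "response")
```

Each row of profile is the same car with a different hp, so the yhat column traces how the prediction for this one observation reacts to hp alone.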
And for selected categorical variables. Note that sibsp is numerical but is presented here as a categorical variable.
plot(sp_rf,
     variables = c("class", "embarked", "gender", "sibsp"),
     variable_type = "categorical")
It looks like the most important features for this passenger are
age and gender. Overall, his odds of survival
are higher than for the average passenger, mainly because of his young
age and despite being male.
Ceteris Paribus profiles can also be computed for a sample of passengers and then grouped with cluster_profiles().
passangers <- select_sample(titanic_imputed, n = 100)
sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles :
#>   _vname_         _label_       _x_ _cluster_    _yhat_ _ids_
#> 1    fare Random Forest_1 0.0000000         1 0.2372045     0
#> 2   parch Random Forest_1 0.0000000         1 0.1658665     0
#> 3   sibsp Random Forest_1 0.0000000         1 0.1699181     0
#> 4     age Random Forest_1 0.1666667         1 0.4653162     0
#> 5   parch Random Forest_1 1.0000000         1 0.2539302     0
#> 6   sibsp Random Forest_1 1.0000000         1 0.1519697     0
plot(sp_rf, alpha = 0.1) +
show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
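The clustering step can be pictured as follows: compute one profile per observation, treat each profile as a vector, group similar vectors, and average within groups. The base-R sketch below uses hierarchical clustering via hclust() as an assumption about a reasonable method; it is not claimed to match what cluster_profiles() does internally, and the toy logistic model on mtcars is again a hypothetical stand-in.

```r
# Base-R sketch of clustering individual profiles (illustrative only).

# Toy classifier on a built-in dataset (assumption for this sketch)
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())
hp_grid <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 10)

# One Ceteris Paribus profile per observation: rows = observations,
# columns = grid points of hp
profiles <- t(sapply(seq_len(nrow(mtcars)), function(i) {
  rows <- mtcars[rep(i, length(hp_grid)), ]
  rows$hp <- hp_grid
  predict(fit, newdata = rows, type = "response")
}))

# Group similar profiles into k = 3 clusters
clusters <- cutree(hclust(dist(profiles), method = "ward.D2"), k = 3)

# Aggregated profile per cluster: column-wise mean within each cluster
aggregated <- rowsum(profiles, clusters) / as.vector(table(clusters))
```

Here matplot(hp_grid, t(aggregated), type = "l") would draw the three aggregated profiles on one panel.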
sessionInfo()
#> R version 4.2.2 (2022-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Monterey 12.5.1
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] C/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.4.0 ranger_0.14.1 ingredients_2.3.0 DALEX_2.4.2
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.9 highr_0.10 bslib_0.4.2 compiler_4.2.2
#> [5] pillar_1.8.1 jquerylib_0.1.4 tools_4.2.2 digest_0.6.31
#> [9] lattice_0.20-45 jsonlite_1.8.4 evaluate_0.19 lifecycle_1.0.3
#> [13] tibble_3.1.8 gtable_0.3.1 pkgconfig_2.0.3 rlang_1.0.6
#> [17] Matrix_1.5-1 cli_3.6.0 rstudioapi_0.14 yaml_2.3.6
#> [21] xfun_0.36 fastmap_1.1.0 withr_2.5.0 stringr_1.5.0
#> [25] knitr_1.41 vctrs_0.5.1 sass_0.4.4 grid_4.2.2
#> [29] glue_1.6.2 R6_2.5.1 fansi_1.0.3 rmarkdown_2.19
#> [33] farver_2.1.1 magrittr_2.0.3 scales_1.2.1 htmltools_0.5.4
#> [37] colorspace_2.0-3 labeling_0.4.2 utf8_1.2.2 stringi_1.7.12
#> [41] munsell_0.5.0 cachem_1.0.6