This vignette is a guide to policy_eval() and some of
the associated S3 methods. The purpose of policy_eval is to
estimate (evaluate) the value of a user-defined policy or a policy
learning algorithm. For details on the methodology, see the associated
paper (Nordland and Holst 2023).
As a general setup, we consider a fixed two-stage problem. We simulate
data using sim_two_stage() and create a
policy_data object using policy_data():
library("polle")
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
action = c("A_1", "A_2"),
baseline = c("B", "BB"),
covariates = list(L = c("L_1", "L_2"),
C = c("C_1", "C_2")),
utility = c("U_1", "U_2", "U_3"))
pd
## Policy data with n = 2000 observations and maximal K = 2 stages.
##
## action
## stage 0 1 n
## 1 1017 983 2000
## 2 819 1181 2000
##
## Baseline covariates: B, BB
## State covariates: L, C
## Average utility: 0.84
User-defined policies are created using policy_def(). In
this case, we define a simple static policy always selecting action
'1':
As we want to apply the same policy function at both stages, we set
reuse = TRUE.
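A minimal sketch of the policy definition described above, assuming the polle package is installed. The name "(A=1)" is chosen to match the label in the printed output further down. Applying the policy object to the policy_data object returns the action it selects for every id and stage (we assume the selected action is stored in a column named d).

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# Static policy: always select action '1'; reuse = TRUE applies the
# same policy function at both stages
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")

# One recommended action per id and stage (2000 ids x 2 stages)
head(p1(pd))
```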
policy_eval() implements three types of policy
evaluation: inverse probability weighting (IPW) estimation, outcome
regression estimation, and doubly robust (DR) estimation. As doubly
robust estimation combines the two other types, we focus on
this approach. For details on the implementation, see Algorithm 1 in
(Nordland and Holst 2023).
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=1) 0.8213 0.1115 0.6027 1.04 1.796e-13
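The DR estimate above comes from a call to policy_eval() with type = "dr". As a sketch, assuming the type values "ipw" and "or" select the other two estimators, the three evaluations of the same policy can be compared side by side (object names are illustrative):

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")

# Evaluate the same static policy with each estimator type
pe_ipw <- policy_eval(policy_data = pd, policy = p1, type = "ipw")
pe_or  <- policy_eval(policy_data = pd, policy = p1, type = "or")
pe1    <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# Point estimates collected in a named numeric vector
est_by_type <- vapply(list(ipw = pe_ipw, or = pe_or, dr = pe1),
                      function(x) unname(coef(x)), numeric(1))
est_by_type
```

The DR estimate remains consistent if either the g-model or the Q-models are correctly specified, which is the reason for focusing on it.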
policy_eval() returns an object of class
policy_eval, which prints like a lava::estimate
object. The policy value estimate and its variance are available via
coef() and vcov():
## [1] 0.8213233
## [,1]
## [1,] 0.01244225
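The printed standard error, confidence limits, and p-value follow from these two numbers. A base R sketch, assuming a normal-approximation (Wald) interval and test, which reproduces the printed Std.Err, 2.5%, 97.5%, and P-value columns from the values returned by coef() and vcov() above:

```r
# Point estimate and variance as printed by coef(pe1) and vcov(pe1)
theta <- 0.8213233
v <- 0.01244225

se <- sqrt(v)                            # standard error
z <- qnorm(0.975)                        # 97.5% standard normal quantile
ci <- c(lower = theta - z * se, upper = theta + z * se)
p_value <- 2 * pnorm(-abs(theta / se))   # two-sided Wald test of value = 0

c(estimate = theta, std.err = se, ci, p.value = p_value)
```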
policy_eval objects
The policy_eval object behaves like a
lava::estimate object, which can also be directly accessed
using estimate().
estimate objects make it easy to work with estimates
with an iid decomposition given by the influence curve/function; see the
estimate
vignette.
The influence curve is available via IC():
## [,1]
## [1,] 2.5515875
## [2,] -5.6787782
## [3,] 4.9506000
## [4,] 2.0661524
## [5,] 0.7939672
## [6,] -2.2932160
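A sketch of how the influence curve relates to the reported variance. We assume the usual convention that the variance estimate is the sum of squared influence-curve values divided by n^2 (equivalently, their mean square divided by n) and check this numerically rather than presenting it as polle's exact internal formula. lava is loaded explicitly to make the IC() generic available.

```r
library("polle")
library("lava")  # IC() generic

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

ic <- IC(pe1)   # n x 1 matrix: one influence-curve value per observation
n <- nrow(ic)

mean(ic)                 # should be (numerically) zero
v_ic <- sum(ic^2) / n^2  # variance implied by the influence curve
c(from_IC = v_ic, vcov = vcov(pe1)[1, 1])
```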
Merging estimate objects allows the user to get inference
for transformations of the estimates via the Delta method. Here we get
inference for the average treatment effect, both as a difference and as
a ratio:
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
pe0 <- policy_eval(policy_data = pd,
policy = p0,
type = "dr")
(est <- merge(pe0, pe1))
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 4.871e-01
## ────────────────
## E[Z(d)]: d=(A=1) 0.82132 0.1115 0.6027 1.0399 1.796e-13
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-difference 0.8825 0.1338 0.6203 1.145 4.25e-11
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-ratio -13.41 19.6 -51.83 25 0.4937
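The difference and ratio above are transformations of the merged estimate. A sketch, assuming lava's estimate() applied to the merged object with a transformation function and a labels argument; the second half recomputes the same Delta-method standard errors by hand from the two influence curves. For g(x) = x2 - x1 the gradient is (-1, 1); for g(x) = x2 / x1 it is (-x2 / x1^2, 1 / x1).

```r
library("polle")
library("lava")  # estimate() and IC() generics

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")
pe0 <- policy_eval(policy_data = pd, policy = p0, type = "dr")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# Merged estimate: coefficient 1 is (A=0), coefficient 2 is (A=1)
est <- merge(pe0, pe1)
ate_diff  <- estimate(est, function(x) x[2] - x[1], labels = "ATE-difference")
ate_ratio <- estimate(est, function(x) x[2] / x[1], labels = "ATE-ratio")

# Delta method by hand: influence curve of g(theta) is IC %*% gradient
theta <- c(coef(pe0), coef(pe1))
ic <- cbind(IC(pe0), IC(pe1))
n <- nrow(ic)
ic_diff  <- ic[, 2] - ic[, 1]
ic_ratio <- ic[, 2] / theta[1] - theta[2] * ic[, 1] / theta[1]^2
se_manual <- c(diff = sqrt(sum(ic_diff^2)) / n,
               ratio = sqrt(sum(ic_ratio^2)) / n)
se_manual
```

The large standard error of the ratio reflects that the denominator, the value under (A=0), is close to zero.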
So far we have relied on the default generalized linear models for
the nuisance g-models and Q-models. By default, a single g-model is trained
across all stages using the state/Markov-type history; see the
policy_data vignette. Use get_g_functions() to
get access to the fitted model:
## $all_stages
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) L C B BBgroup2 BBgroup3
## 0.08285 0.03094 0.97993 -0.05753 -0.13970 -0.06122
##
## Degrees of Freedom: 3999 Total (i.e. Null); 3994 Residual
## Null Deviance: 5518
## Residual Deviance: 4356 AIC: 4368
##
##
## attr(,"full_history")
## [1] FALSE
The g-functions can be used as input to a new policy evaluation:
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 0.4871
or we can get the associated predicted values:
## Key: <id, stage>
## id stage g_0 g_1
## <int> <int> <num> <num>
## 1: 1 1 0.15628741 0.84371259
## 2: 1 2 0.08850558 0.91149442
## 3: 2 1 0.92994454 0.07005546
## 4: 2 2 0.92580890 0.07419110
## 5: 3 1 0.11184451 0.88815549
## 6: 3 2 0.08082666 0.91917334
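A sketch of both uses of the fitted g-functions, assuming predict() takes the g-functions and a policy_data object and returns one row per id and stage with the columns g_0 and g_1 shown above. Since the g-model does not depend on the evaluated policy, reusing it for (A=0) gives the same estimate as fitting it anew.

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# Fitted g-functions (propensity models) from the first evaluation
g_funs <- get_g_functions(pe1)

# Reuse them for another policy instead of fitting new g-models
pe0_g <- policy_eval(policy_data = pd, g_functions = g_funs,
                     policy = p0, type = "dr")

# Predicted action probabilities for each id and stage
g_pred <- predict(g_funs, pd)
head(g_pred)
```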
Similarly, we can inspect the Q-functions using
get_q_functions():
## $stage_1
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## 0.232506 0.682422 0.454642 0.039021 -0.070152 -0.184704
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.171734 -0.010746 0.938791 0.003772 0.157200 0.270711
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 7689
## Residual Deviance: 3599 AIC: 6877
##
##
## $stage_2
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## -0.043324 0.147356 0.002376 -0.042036 0.005331 -0.001128
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.108404 0.024424 0.962591 -0.059177 -0.102084 0.094688
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 3580
## Residual Deviance: 1890 AIC: 5588
##
##
## attr(,"full_history")
## [1] FALSE
Note that a model is trained for each stage. Again, we can predict
from the Q-models using predict().
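A sketch of inspecting and predicting from the Q-functions, assuming predict() accepts the Q-functions and a policy_data object in the same way as for the g-functions and returns one row per id and stage:

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# One fitted Q-model per stage
q_funs <- get_q_functions(pe1)
names(q_funs)

# Predicted Q-values for each id and stage
q_pred <- predict(q_funs, pd)
head(q_pred)
```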
Usually, we want to specify the nuisance models ourselves using the
g_models and q_models arguments:
pe1 <- policy_eval(pd,
policy = p1,
g_models = list(
g_sl(formula = ~ BB + L_1, SL.library = c("SL.glm", "SL.ranger")),
g_sl(formula = ~ BB + L_1 + C_2, SL.library = c("SL.glm", "SL.ranger"))
),
g_full_history = TRUE,
q_models = list(
q_glm(formula = ~ A * (B + C_1)), # including action interactions
q_glm(formula = ~ A * (B + C_1 + C_2)) # including action interactions
),
q_full_history = TRUE)
## Loading required namespace: ranger
Here we train a super learner g-model for each stage using the full
available history and a generalized linear Q-model for each stage, also
using the full history. The
formula argument is used to construct the model frame
passed to the model for training (and prediction). The valid formula
terms, which depend on g_full_history and
q_full_history, are available via
get_history_names():
## [1] "L" "C" "B" "BB"
## [1] "L_1" "C_1" "B" "BB"
## [1] "A_1" "L_1" "L_2" "C_1" "C_2" "B" "BB"
Remember that the action variable at the current stage is always
named A. Some models, like glm, require
interactions to be specified via the model frame. Thus, for such models,
it is important to include action interaction terms in the
Q-model formulas.
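A sketch combining both points, assuming get_history_names() takes a stage argument for the full history (matching the three printed vectors above). It is a glm-only variant of the super learner example, using full-history formula terms and explicit action interactions in the Q-models (object names are illustrative):

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")

# Valid formula terms
get_history_names(pd)             # state/Markov-type history
get_history_names(pd, stage = 1)  # full history at stage 1
get_history_names(pd, stage = 2)  # full history at stage 2

# glm nuisance models on the full history; A is the current-stage action,
# and A * (...) adds the action interaction terms glm needs
pe_glm <- policy_eval(pd,
                      policy = p1,
                      g_models = list(g_glm(formula = ~ BB + L_1),
                                      g_glm(formula = ~ BB + L_1 + C_2)),
                      g_full_history = TRUE,
                      q_models = list(q_glm(formula = ~ A * (B + C_1)),
                                      q_glm(formula = ~ A * (B + C_1 + C_2))),
                      q_full_history = TRUE)
pe_glm
```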
The value of a learned policy is an important performance measure,
and policy_eval() allows for direct evaluation of a given
policy learning algorithm. For details, see Algorithm 4 in (Nordland and Holst 2023).
In polle, policy learning algorithms are specified using
policy_learn(); see the associated vignette. These
functions can be directly evaluated in policy_eval():
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
In the above example we evaluate the policy estimated via Q-learning.
Alternatively, we can first learn the policy and then pass it to
policy_eval():
p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
policy_eval(pd,
policy = get_policy(p_ql))
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
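A sketch putting the two routes side by side. Without cross-fitting (M = 1), both fit the Q-learning policy and the default glm nuisance models on the full data, so we would expect identical estimates, consistent with the two printed outputs above:

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# (a) Evaluate the policy learning algorithm directly
pe_direct <- policy_eval(pd, policy_learn = policy_learn(type = "ql"))

# (b) Learn the policy first, then evaluate the fitted policy
p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
pe_twostep <- policy_eval(pd, policy = get_policy(p_ql))

c(direct = coef(pe_direct)[[1]], two_step = coef(pe_twostep)[[1]])
```

The two routes differ once cross-fitting is used, since only the direct route refits the learner within each training fold.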
A key feature of policy_eval() is that it allows for
easy cross-fitting of the nuisance models as well as the learned policy.
Here we specify two-fold cross-fitting via the M
argument:
Specifically, both the nuisance models and the optimal policy are fitted on each training fold. Subsequently, the doubly robust value score is calculated on the validation folds.
The policy_eval object now consists of a list of
policy_eval objects associated with each fold:
## [1] 3 4 5 7 8 10
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.261 0.09456 1.075 1.446 1.538e-40
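A sketch of the two-fold cross-fitted evaluation of the Q-learning algorithm. The fold assignment is random, so we set a seed; the seed value here is arbitrary, and the estimate need not match the printed one exactly.

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

set.seed(1)  # arbitrary seed for the random fold split
# M = 2: learner and nuisance models are fitted on each training fold,
# and the DR value score is computed on the corresponding validation fold
pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)
pe_cf
```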
To save memory, particularly when cross-fitting, it is
possible to skip storing the nuisance models via the
save_g_functions and save_q_functions
arguments.
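A sketch comparing a cross-fitted evaluation with and without stored nuisance models. Resetting the same seed before each call is meant to give both runs the same folds, under the assumption that the fold split is driven by R's random number generator; the estimate should then be unchanged while the returned object shrinks.

```r
library("polle")

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

set.seed(1)
pe_full <- policy_eval(pd, policy_learn = policy_learn(type = "ql"), M = 2)

set.seed(1)
pe_small <- policy_eval(pd, policy_learn = policy_learn(type = "ql"), M = 2,
                        save_g_functions = FALSE,  # do not keep g-models
                        save_q_functions = FALSE)  # do not keep Q-models

# Same estimate, smaller object (sizes in bytes)
c(full = coef(pe_full)[[1]], small = coef(pe_small)[[1]])
c(full = as.numeric(object.size(pe_full)),
  small = as.numeric(object.size(pe_small)))
```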
future.apply
It is easy to parallelize the cross-fitting procedure via the
future.apply package:
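A sketch of parallel cross-fitting, assuming that policy_eval() distributes the fold fits via future.apply, as the text above indicates, so the parallel backend is chosen with future's plan(). Two local background sessions are used here; the number of workers is an arbitrary choice.

```r
library("polle")
library("future.apply")  # also attaches future, which provides plan()

# Simulated two-stage data and policy_data object, as in the setup above
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

future::plan(future::multisession, workers = 2)  # two local R sessions

set.seed(1)
pe_cf_par <- policy_eval(pd,
                         policy_learn = policy_learn(type = "ql"),
                         M = 2)

future::plan(future::sequential)  # reset to sequential processing
pe_cf_par
```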
sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin23.5.0
## Running under: macOS Sonoma 14.6.1
##
## Matrix products: default
## BLAS: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib
## LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: Europe/Copenhagen
## tzcode source: internal
##
## attached base packages:
## [1] splines stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] ggplot2_3.5.1 data.table_1.15.4 polle_1.5
## [4] SuperLearner_2.0-29 gam_1.22-4 foreach_1.5.2
## [7] nnls_1.5
##
## loaded via a namespace (and not attached):
## [1] sass_0.4.9 utf8_1.2.4 future_1.33.2
## [4] lattice_0.22-6 listenv_0.9.1 digest_0.6.36
## [7] magrittr_2.0.3 evaluate_0.24.0 grid_4.4.1
## [10] iterators_1.0.14 mvtnorm_1.2-5 policytree_1.2.3
## [13] fastmap_1.2.0 jsonlite_1.8.8 Matrix_1.7-0
## [16] survival_3.6-4 fansi_1.0.6 scales_1.3.0
## [19] numDeriv_2016.8-1.1 codetools_0.2-20 jquerylib_0.1.4
## [22] lava_1.8.0 cli_3.6.3 rlang_1.1.4
## [25] mets_1.3.4 parallelly_1.37.1 future.apply_1.11.2
## [28] munsell_0.5.1 withr_3.0.0 cachem_1.1.0
## [31] yaml_2.3.8 tools_4.4.1 parallel_4.4.1
## [34] colorspace_2.1-0 ranger_0.16.0 globals_0.16.3
## [37] vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
## [40] pkgconfig_2.0.3 timereg_2.0.5 progressr_0.14.0
## [43] bslib_0.7.0 pillar_1.9.0 gtable_0.3.5
## [46] Rcpp_1.0.13 glue_1.7.0 xfun_0.45
## [49] tibble_3.2.1 highr_0.11 knitr_1.47
## [52] farver_2.1.2 htmltools_0.5.8.1 rmarkdown_2.27
## [55] labeling_0.4.3 compiler_4.4.1