A general test for conditional independence in supervised learning algorithms. Implements a conditional variable importance measure which can be applied to any supervised learning algorithm and loss function. Provides statistical inference procedures without parametric assumptions and applies equally well to continuous and categorical predictors and outcomes.
Usage
cpi(
task,
learner,
resampling = NULL,
test_data = NULL,
measure = NULL,
test = "t",
log = FALSE,
B = 1999,
alpha = 0.05,
x_tilde = NULL,
knockoff_fun = function(x) knockoff::create.second_order(as.matrix(x)),
groups = NULL,
verbose = FALSE
)
Arguments
- task
The prediction mlr3 task, see examples.
- learner
The mlr3 learner used in CPI. If you pass a string, the learner will be created via mlr3::lrn().
- resampling
Resampling strategy, an mlr3 resampling object (e.g. rsmp("holdout")), "oob" (out-of-bag) or "none" (in-sample loss).
- test_data
External validation data, used instead of resampling; see the sketch after this list.
- measure
Performance measure (loss). By default, MSE ("regr.mse") is used for regression and log-loss ("classif.logloss") for classification.
- test
Statistical test to perform, one of "t" (t-test, default), "wilcox" (Wilcoxon signed-rank test), "binom" (binomial test), "fisher" (Fisher permutation test) or "bayes" (Bayesian testing, computationally intensive!). See Details.
- log
Set to TRUE for multiplicative CPI (\(\lambda\)), to FALSE (default) for additive CPI (\(\Delta\)).
- B
Number of permutations for the Fisher permutation test.
- alpha
Significance level for confidence intervals.
- x_tilde
Knockoff matrix or data.frame. If not given (the default), it will be created with the function given in knockoff_fun; see the final example.
- knockoff_fun
Function to generate knockoffs. Default: knockoff::create.second_order with matrix argument.
- groups
(Named) list with groups. Set to NULL (default) for no groups, i.e. compute CPI for each feature. See examples.
- verbose
Verbose output of the resampling procedure.
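A brief sketch of external validation via test_data, combined with the multiplicative CPI (the train/holdout split below is hypothetical):

library(mlr3)
library(mlr3learners)
train <- mtcars[1:22, ]
holdout <- mtcars[23:32, ]
cpi(task = as_task_regr(train, target = "mpg"),
    learner = lrn("regr.lm"),
    test_data = holdout,   # evaluated on external data, no resampling needed
    log = TRUE)            # multiplicative CPI (lambda) instead of additive (Delta)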
Value
For test = "bayes", a list of BEST objects. In any other case, a data.frame with a row for each feature and the following columns:
- Variable/Group
Variable/group name
- CPI
CPI value
- SE
Standard error
- test
Testing method
- statistic
Test statistic (only for t-test, Wilcoxon and binomial test)
- estimate
Estimated mean (for t-test), median (for Wilcoxon test), or proportion of \(\Delta\)-values greater than 0 (for binomial test).
- p.value
p-value
- ci.lo
Lower limit of the (1 - alpha) * 100% confidence interval
Note that NA values are not an error but the result of a CPI value of 0, i.e. no difference in model performance after replacing a feature with its knockoff.
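Because the non-Bayesian result is a plain data.frame, downstream filtering is straightforward. A brief usage sketch (res is a hypothetical result object):

res <- cpi(task = tsk("mtcars"), learner = lrn("regr.lm"),
           resampling = rsmp("holdout"))
res[res$p.value < 0.05, ]  # features with significant CPI at the 5% level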
Details
This function computes the conditional predictive impact (CPI) of one or several features on a given supervised learning task. This represents the mean error inflation when replacing a true variable with its knockoff. Large CPI values are evidence that the feature(s) in question have high conditional variable importance -- i.e., the fitted model relies on the feature(s) to predict the outcome, even after accounting for the signal from all remaining covariates.
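As a conceptual illustration of this idea (a simplified sketch, not the package internals; the simulated data and feature names are hypothetical), the additive CPI for a single feature is the mean per-observation loss inflation after substituting that feature with its knockoff copy:

set.seed(1)
n <- 200
x <- matrix(rnorm(n * 4), n, 4)
colnames(x) <- paste0("x", 1:4)
dat <- data.frame(y = x[, 1] + rnorm(n), x)
fit <- lm(y ~ ., data = dat[1:100, ])            # train on the first half
test <- dat[101:200, ]
err_orig <- (test$y - predict(fit, test))^2      # per-observation squared error
x_ko <- knockoff::create.second_order(as.matrix(test[, -1]))
test_ko <- test
test_ko$x1 <- x_ko[, 1]                          # replace x1 with its knockoff
err_ko <- (test$y - predict(fit, test_ko))^2
delta <- err_ko - err_orig                       # error inflation per observation
mean(delta)                                      # additive CPI estimate for x1
t.test(delta, alternative = "greater")           # one-sided test, H0: CPI <= 0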
We build on the mlr3 framework, which provides a unified interface for training models, specifying loss functions, and estimating generalization error. See the package documentation for more information.
Methods are implemented for frequentist and Bayesian inference. The default is test = "t", which is fast and powerful for most sample sizes. The Wilcoxon signed-rank test (test = "wilcox") may be more appropriate if the CPI distribution is skewed, while the binomial test (test = "binom") requires essentially no assumptions but may have less power. For small sample sizes, we recommend permutation tests (test = "fisher") or Bayesian methods (test = "bayes"). In the latter case, default priors are assumed. See the BEST package for more information.
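For example, a permutation-based analysis can be requested through the same interface (a usage sketch; B = 999 is an illustrative choice, the default is 1999):

library(mlr3)
library(mlr3learners)
cpi(task = tsk("mtcars"), learner = lrn("regr.lm"),
    resampling = rsmp("holdout"),
    test = "fisher", B = 999)  # Fisher permutation test with 999 permutations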
For parallel execution, register a backend, e.g. with doParallel::registerDoParallel().
References
Watson, D. & Wright, M. (2021). Testing conditional independence in supervised learning algorithms. Machine Learning, 110(8): 2107-2129. doi:10.1007/s10994-021-06030-6
Candès, E., Fan, Y., Janson, L., & Lv, J. (2018). Panning for gold: 'model-X' knockoffs for high dimensional controlled variable selection. J. R. Statist. Soc. B, 80(3): 551-577. doi:10.1111/rssb.12265
Examples
library(mlr3)
library(mlr3learners)
# Regression with linear model and holdout validation
cpi(task = tsk("mtcars"), learner = lrn("regr.lm"),
resampling = rsmp("holdout"))
#> Variable CPI SE test statistic estimate p.value
#> 1 am -1.133348e+00 1.3959812163 t -0.8118647 -1.133348e+00 0.7821157
#> 2 carb -5.822104e-04 0.0001795117 t -3.2433012 -5.822104e-04 0.9955900
#> 3 cyl -4.321054e-02 0.2576547016 t -0.1677071 -4.321054e-02 0.5649216
#> 4 disp -4.202802e-05 0.0002564478 t -0.1638852 -4.202802e-05 0.5634568
#> 5 drat -1.619294e+00 1.6830802249 t -0.9621013 -1.619294e+00 0.8206572
#> 6 gear -3.031974e-01 0.6996023258 t -0.4333853 -3.031974e-01 0.6630329
#> 7 hp 4.263017e-01 0.8741337577 t 0.4876848 4.263017e-01 0.3181440
#> 8 qsec 1.798770e+00 1.6171852379 t 1.1122847 1.798770e+00 0.1460199
#> 9 vs 1.171381e-01 0.1329319081 t 0.8811889 1.171381e-01 0.1994499
#> 10 wt 3.735741e-01 0.7862931170 t 0.4751080 3.735741e-01 0.3224583
#> ci.lo
#> 1 -3.6635095494
#> 2 -0.0009075683
#> 3 -0.5101996653
#> 4 -0.0005068297
#> 5 -4.6698112078
#> 6 -1.5711993979
#> 7 -1.1580317415
#> 8 -1.1323149473
#> 9 -0.1237957921
#> 10 -1.0515515722
# \donttest{
# Classification with logistic regression, log-loss and t-test
cpi(task = tsk("wine"),
learner = lrn("classif.glmnet", predict_type = "prob", lambda = 0.1),
resampling = rsmp("holdout"),
measure = "classif.logloss", test = "t")
#> Variable CPI SE test statistic estimate
#> 1 alcalinity 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 2 alcohol 4.177518e-02 1.408234e-02 t 2.9664952 4.177518e-02
#> 3 ash -7.276089e-05 4.319428e-05 t -1.6845027 -7.276089e-05
#> 4 color 3.743968e-03 1.019247e-02 t 0.3673270 3.743968e-03
#> 5 dilution 1.356989e-02 1.680332e-02 t 0.8075716 1.356989e-02
#> 6 flavanoids 4.007469e-06 6.786567e-06 t 0.5905002 4.007469e-06
#> 7 hue 5.096126e-03 5.359962e-03 t 0.9507764 5.096126e-03
#> 8 magnesium 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 9 malic 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 10 nonflavanoids 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 11 phenols 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 12 proanthocyanins 0.000000e+00 0.000000e+00 t 0.0000000 0.000000e+00
#> 13 proline 6.381548e-02 2.147100e-02 t 2.9721702 6.381548e-02
#> p.value ci.lo
#> 1 1.000000000 0.000000e+00
#> 2 0.002184548 1.823581e-02
#> 3 0.951270450 -1.449624e-04
#> 4 0.357356037 -1.329328e-02
#> 5 0.211318451 -1.451775e-02
#> 6 0.278574658 -7.336636e-06
#> 7 0.172831591 -3.863334e-03
#> 8 1.000000000 0.000000e+00
#> 9 1.000000000 0.000000e+00
#> 10 1.000000000 0.000000e+00
#> 11 1.000000000 0.000000e+00
#> 12 1.000000000 0.000000e+00
#> 13 0.002149876 2.792556e-02
# Use your own data (and out-of-bag loss with random forest)
mytask <- as_task_classif(iris, target = "Species")
mylearner <- lrn("classif.ranger", predict_type = "prob", keep.inbag = TRUE)
cpi(task = mytask, learner = mylearner,
resampling = "oob", measure = "classif.logloss")
#> Variable CPI SE test statistic estimate
#> 1 Petal.Length 2.862520e-04 0.0018381126 t 0.1557315 2.862520e-04
#> 2 Petal.Width 5.661919e-03 0.0040700859 t 1.3911056 5.661919e-03
#> 3 Sepal.Length -6.706178e-05 0.0003713225 t -0.1806026 -6.706178e-05
#> 4 Sepal.Width -6.236105e-04 0.0021435500 t -0.2909241 -6.236105e-04
#> p.value ci.lo
#> 1 0.43822774 -0.0027560901
#> 2 0.08313359 -0.0010746613
#> 3 0.57153752 -0.0006816541
#> 4 0.61424305 -0.0041714956
# Group CPI
cpi(task = tsk("iris"),
learner = lrn("classif.ranger", predict_type = "prob", num.trees = 10),
resampling = rsmp("cv", folds = 3),
groups = list(Sepal = 1:2, Petal = 3:4))
#> Group CPI SE test statistic estimate p.value
#> 1 Sepal 0.007383664 0.005088275 t 1.45111354 0.007383664 0.07442518
#> 2 Petal 0.000229964 0.003705083 t 0.06206717 0.000229964 0.47529626
#> ci.lo
#> 1 -0.001038166
#> 2 -0.005902484
# }
if (FALSE) {
# Bayesian testing
res <- cpi(task = tsk("iris"),
learner = lrn("classif.glmnet", predict_type = "prob", lambda = 0.1),
resampling = rsmp("holdout"),
measure = "classif.logloss", test = "bayes")
plot(res$Petal.Length)
# Parallel execution
doParallel::registerDoParallel()
cpi(task = tsk("wine"),
learner = lrn("classif.glmnet", predict_type = "prob", lambda = 0.1),
resampling = rsmp("cv", folds = 5))
# Use sequential knockoffs for categorical features
# package available here: https://github.com/kormama1/seqknockoff
mytask <- as_task_regr(iris, target = "Petal.Length")
cpi(task = mytask, learner = lrn("regr.ranger"),
resampling = rsmp("holdout"),
knockoff_fun = seqknockoff::knockoffs_seq)
}
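A precomputed knockoff matrix can also be passed directly via x_tilde (a sketch, assuming the knockoffs must align row- and column-wise with the task's feature matrix):

library(mlr3)
library(mlr3learners)
X <- as.matrix(mtcars[, setdiff(names(mtcars), "mpg")])  # feature matrix of the task
X_tilde <- knockoff::create.second_order(X)              # precomputed knockoffs
cpi(task = tsk("mtcars"), learner = lrn("regr.lm"),
    resampling = rsmp("holdout"), x_tilde = X_tilde)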