
Uses a pre-trained ARF model to estimate leaf and distribution parameters.

Usage

forde(
  arf,
  x,
  oob = FALSE,
  family = "truncnorm",
  finite_bounds = FALSE,
  alpha = 0,
  epsilon = 0,
  parallel = TRUE
)

Arguments

arf

Pre-trained adversarial_rf. Alternatively, any object of class ranger.

x

Training data for estimating parameters.

oob

Only use out-of-bag samples for parameter estimation? If TRUE, x must be the same dataset used to train arf.
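The notion of out-of-bag samples can be illustrated in base R. This is a minimal sketch, not forde's internals: it assumes a single hypothetical tree grown on a bootstrap sample of row indices, and the names inbag and oob_rows are illustrative. Because the out-of-bag set is defined by row indices of the training data, it only makes sense when x is that same dataset. (In ranger itself, per-tree in-bag counts are stored when fitting with keep.inbag = TRUE.)

```r
set.seed(42)
n <- nrow(iris)
# Bootstrap sample of row indices used to grow one hypothetical tree
inbag <- sample.int(n, size = n, replace = TRUE)
# Out-of-bag rows: those never drawn into the bootstrap sample
oob_rows <- setdiff(seq_len(n), inbag)
# On average about (1 - 1/n)^n, i.e. roughly 37%, of rows are out-of-bag
length(oob_rows) / n
```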

family

Distribution to use for density estimation of continuous features. Current options include truncated normal (the default family = "truncnorm") and uniform (family = "unif"). See Details.
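To make the two continuous families concrete, the sketch below evaluates a truncated normal density and a uniform density for a single leaf in base R. The helper dtrunc_sketch is hypothetical and only illustrates the formula (a normal density renormalized by its mass inside the leaf's bounds); it is not the function forde uses internally. The leaf values are taken from the example output further down this page.

```r
# Hypothetical helper: density of N(mu, sigma^2) truncated to [a, b]
dtrunc_sketch <- function(x, a, b, mu, sigma) {
  # Probability mass of the untruncated normal inside [a, b]
  mass <- pnorm(b, mean = mu, sd = sigma) - pnorm(a, mean = mu, sd = sigma)
  # Renormalize inside the bounds; zero density outside them
  ifelse(x >= a & x <= b, dnorm(x, mean = mu, sd = sigma) / mass, 0)
}

# Leaf 1, Petal.Length: min = -Inf, max = 1.55, mu = 1.35, sigma = 0.0707
dtrunc_sketch(c(1.3, 1.6), a = -Inf, b = 1.55, mu = 1.35, sigma = 0.07071068)

# Uniform alternative (leaf 2, Petal.Length): needs finite bounds
dunif(c(4.1, 4.4), min = 4.05, max = 4.30)
```

Note that the uniform family cannot be used with infinite leaf bounds, which is why finite_bounds and epsilon matter when family = "unif".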

finite_bounds

Impose finite bounds on all continuous variables?

alpha

Optional pseudocount for Laplace smoothing of categorical features. This avoids zero-mass points when test data fall outside the support of training data. Effectively parametrizes a flat Dirichlet prior on multinomial likelihoods.
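Laplace smoothing with pseudocount alpha gives each level k of a categorical feature the probability (n_k + alpha) / (n + K * alpha), where n_k is the count of level k in the leaf, n the leaf size and K the number of levels; this is the posterior mean under a flat Dirichlet(alpha) prior. The base-R sketch below uses a hypothetical helper laplace_probs and assumes the pseudocount is added to every level of the factor, including levels absent from the leaf. The toy leaf mirrors leaf 4 in the example output.

```r
# Hypothetical helper: smoothed category proportions within one leaf
laplace_probs <- function(values, levels, alpha = 0) {
  # Count each level, including levels that never occur in this leaf
  counts <- table(factor(values, levels = levels))
  # Add the pseudocount to every level, then renormalize
  probs <- (counts + alpha) / (sum(counts) + alpha * length(levels))
  setNames(as.numeric(probs), levels)
}

leaf_species <- c("versicolor", "virginica")
lvls <- c("setosa", "versicolor", "virginica")
laplace_probs(leaf_species, lvls, alpha = 0)  # setosa 0.0, versicolor 0.5, virginica 0.5
laplace_probs(leaf_species, lvls, alpha = 1)  # setosa 0.2, versicolor 0.4, virginica 0.4
```

With alpha = 0, a setosa test point would receive zero mass in this leaf; any alpha > 0 removes that zero.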

epsilon

Optional slack parameter on empirical bounds when family = "unif" or finite_bounds = TRUE. This avoids zero-density points when test data fall outside the support of training data. The gap between lower and upper bounds is expanded by a factor of 1 + epsilon.
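As a rough illustration of the slack parameter, the sketch below widens an interval so that its width grows by a factor of 1 + epsilon. The helper expand_bounds is hypothetical, and splitting the extra width equally between the two sides is an assumption made here for illustration; the bounds come from leaf 2 (Petal.Length) in the example output.

```r
# Hypothetical helper: widen [lo, hi] so its width grows by a factor 1 + epsilon
# (assumes the extra width is split equally between the two sides)
expand_bounds <- function(lo, hi, epsilon = 0) {
  pad <- epsilon * (hi - lo) / 2
  c(min = lo - pad, max = hi + pad)
}

b <- expand_bounds(4.05, 4.30, epsilon = 0.1)  # width 0.25 becomes 0.275
dunif(4.31, min = 4.05, max = 4.30)            # 0: just outside the raw bounds
dunif(4.31, min = b["min"], max = b["max"])    # positive after expansion
```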

parallel

Compute in parallel? A parallel backend must be registered beforehand, e.g. via doParallel.

Value

A list with 6 elements: (1) parameters for continuous data; (2) parameters for discrete data; (3) leaf indices and coverage; (4) metadata on variables; (5) levels of categorical variables; and (6) the data input class. This list is used for estimating likelihoods with lik and generating data with forge.

Details

forde extracts leaf parameters from a pre-trained forest and learns distribution parameters for data within each leaf. The former include coverage (the proportion of data falling into the leaf) and split criteria. The latter include class proportions for categorical features and means and standard deviations for continuous features. The result is a probabilistic circuit, stored as a list of data.tables, which can be used for various downstream inference tasks.
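The per-leaf quantities described above can be sketched in base R for a single tree. This is not forde's implementation: the leaf assignment below is a hypothetical hand-written tree with two splits (Petal.Length < 2.45, then Petal.Width < 1.75), standing in for the terminal nodes a ranger forest would provide.

```r
# Hypothetical leaf assignment for one tree with two splits
leaf <- ifelse(iris$Petal.Length < 2.45, 1L,
               ifelse(iris$Petal.Width < 1.75, 2L, 3L))

# Coverage: proportion of observations falling into each leaf
cvg <- as.numeric(table(leaf)) / nrow(iris)

# Continuous feature: per-leaf mean and standard deviation
mu    <- tapply(iris$Sepal.Length, leaf, mean)
sigma <- tapply(iris$Sepal.Length, leaf, sd)

# Categorical feature: per-leaf class proportions (each row sums to 1)
prob <- prop.table(table(leaf, iris$Species), margin = 1)
cvg
prob
```

Here the first leaf contains exactly the 50 setosa flowers, so its coverage is 1/3 and its setosa proportion is 1.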

Currently, forde only provides support for a limited number of distributional families: truncated normal or uniform for continuous data, and multinomial for discrete data. Future releases will accommodate a larger set of options.

Though forde was designed to take an adversarial random forest as input, the function's first argument can in principle be any object of class ranger. This allows users to test performance with alternative pipelines (e.g., with supervised forest input). There is also no requirement that x be the data used to fit arf, unless oob = TRUE. In fact, using another dataset here may protect against overfitting. This connects with Wager & Athey's (2018) notion of "honest trees".

References

Watson, D., Blesch, K., Kapar, J., & Wright, M. (2023). Adversarial random forests for density estimation and generative modeling. In Proceedings of the 26th International Conference on Artificial Intelligence and Statistics, pp. 5357-5375.

Wager, S. & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. J. Am. Stat. Assoc., 113(523): 1228-1242.


Examples

arf <- adversarial_rf(iris)
#> Iteration: 0, Accuracy: 86.05%
#> Iteration: 1, Accuracy: 45.67%
psi <- forde(arf, iris)
head(psi)
#> $cnt
#>      f_idx     variable  min  max   mu      sigma
#>   1:     1 Petal.Length -Inf 1.55 1.35 0.07071068
#>   2:     1  Petal.Width -Inf  Inf 0.25 0.07071068
#>   3:     1 Sepal.Length -Inf  Inf 4.45 0.07071068
#>   4:     1  Sepal.Width -Inf 2.95 2.60 0.42426407
#>   5:     2 Petal.Length 4.05 4.30 4.15 0.07071068
#>  ---                                             
#> 720:   180  Sepal.Width 2.90 3.45 3.40 0.11456157
#> 721:   181 Petal.Length -Inf  Inf 4.96 0.13416408
#> 722:   181  Petal.Width 1.65  Inf 1.96 0.26076810
#> 723:   181 Sepal.Length -Inf 6.15 5.88 0.19235384
#> 724:   181  Sepal.Width 2.75  Inf 2.92 0.10954451
#> 
#> $cat
#>      f_idx variable        val prob
#>   1:     1  Species     setosa  1.0
#>   2:     2  Species versicolor  1.0
#>   3:     3  Species versicolor  1.0
#>   4:     4  Species versicolor  0.5
#>   5:     4  Species  virginica  0.5
#>  ---                               
#> 211:   177  Species versicolor  1.0
#> 212:   178  Species     setosa  1.0
#> 213:   179  Species  virginica  1.0
#> 214:   180  Species     setosa  1.0
#> 215:   181  Species  virginica  1.0
#> 
#> $forest
#>      f_idx tree leaf        cvg
#>   1:     1    1   14 0.01333333
#>   2:     2    1   20 0.01333333
#>   3:     3    1   26 0.10000000
#>   4:     4    1   28 0.01333333
#>   5:     5    1   29 0.02000000
#>  ---                           
#> 177:   177   10   48 0.03333333
#> 178:   178   10   50 0.15333333
#> 179:   179   10   52 0.02666667
#> 180:   180   10   55 0.02000000
#> 181:   181   10   56 0.03333333
#> 
#> $meta
#>        variable   class    family decimals
#> 1: Sepal.Length numeric truncnorm        1
#> 2:  Sepal.Width numeric truncnorm        1
#> 3: Petal.Length numeric truncnorm        1
#> 4:  Petal.Width numeric truncnorm        1
#> 5:      Species  factor  multinom       NA
#> 
#> $levels
#>    variable        val
#> 1:  Species  virginica
#> 2:  Species     setosa
#> 3:  Species versicolor
#> 
#> $input_class
#> [1] "data.frame"
#>