Compute partial dependence for an oblique random forest. Partial dependence (PD) shows the expected prediction from a model as a function of a single predictor or multiple predictors. The expectation is marginalized over the values of all other predictors, giving something like a multivariable adjusted estimate of the model's prediction. You can compute partial dependence three ways using a random forest:
using in-bag predictions for the training data
using out-of-bag predictions for the training data
using predictions for a new set of data
See Examples for more details.
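The definition above can be sketched in a few lines of base R. This is a minimal illustration, not how aorsf computes partial dependence internally. A linear model fit with stats::lm() on the built-in mtcars data stands in for the forest, and partial_dependence() is a hypothetical helper. For each requested value, the chosen predictor is set to that value in every row, and the model's predictions are averaged over all rows.

```r
# Stand-in model (assumption): a linear model instead of an oblique forest
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# Hypothetical helper: partial dependence of `model` on one variable
partial_dependence <- function(model, data, variable, values) {
  vapply(values, function(v) {
    data_mod <- data
    data_mod[[variable]] <- v                  # every row takes value v
    mean(predict(model, newdata = data_mod))   # average over the other predictors
  }, numeric(1))
}

pd_wt <- partial_dependence(fit, mtcars, "wt", values = c(2.5, 3.5))
pd_wt
```

Because the stand-in model is linear, the two partial dependence values differ by exactly the coefficient for wt. A forest will generally produce a nonlinear curve.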
Usage
orsf_pd_oob(
object,
pred_spec,
pred_horizon = NULL,
pred_type = NULL,
expand_grid = TRUE,
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
n_thread = NULL,
verbose_progress = NULL,
...
)
orsf_pd_inb(
object,
pred_spec,
pred_horizon = NULL,
pred_type = NULL,
expand_grid = TRUE,
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
n_thread = NULL,
verbose_progress = NULL,
...
)
orsf_pd_new(
object,
pred_spec,
new_data,
pred_horizon = NULL,
pred_type = NULL,
na_action = "fail",
expand_grid = TRUE,
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
n_thread = NULL,
verbose_progress = NULL,
...
)
Arguments
- object
(ObliqueForest) a trained oblique random forest object (see orsf).
- pred_spec
(named list, pspec_auto, or data.frame).
If pred_spec is a named list, each item in the list should be a vector of values that will be used as points in the partial dependence function. The name of each item in the list should indicate which variable will be modified to take the corresponding values.
If pred_spec is created using pred_spec_auto(), all that is needed is the names of variables to use (see pred_spec_auto).
If pred_spec is a data.frame, columns will indicate variable names, values will indicate variable values, and partial dependence will be computed using the inputs on each row.
- pred_horizon
(double) Only relevant for survival forests. A value or vector indicating the time(s) that predictions will be calibrated to. E.g., if you were predicting risk of incident heart failure within the next 10 years, then pred_horizon = 10. pred_horizon can be NULL if pred_type is 'mort', since mortality predictions are aggregated over all event times.
- pred_type
(character) the type of predictions to compute. Valid options for survival are:
'risk': probability of having an event at or before pred_horizon.
'surv': 1 - risk.
'chf': cumulative hazard function.
'mort': mortality prediction.
'time': survival time prediction.
For classification:
'prob': probability for each class.
For regression:
'mean': predicted mean, i.e., the expected value.
- expand_grid
(logical) if TRUE, partial dependence will be computed at all possible combinations of inputs in pred_spec. If FALSE, partial dependence will be computed for each variable in pred_spec, separately.
- prob_values
(numeric) a vector of values between 0 and 1, indicating what quantiles will be used to summarize the partial dependence values at each set of inputs. prob_values should have the same length as prob_labels. The quantiles are calculated based on predictions from object at each set of values indicated by pred_spec.
- prob_labels
(character) a vector of labels, one for each value in prob_values, indicating how the corresponding quantile should be labelled in summarized outputs. prob_labels should have the same length as prob_values.
- boundary_checks
(logical) if TRUE, pred_spec will be checked to make sure the requested values are between the 10th and 90th percentile in the object's training data. If FALSE, these checks are skipped.
- n_thread
(integer) number of threads to use while computing predictions. Default is 0, which allows a suitable number of threads to be used based on availability.
- verbose_progress
(logical) if TRUE, progress will be printed to console. If FALSE (the default), nothing will be printed.
- ...
Further arguments passed to or from other methods (not currently used).
- new_data
a data.frame, tibble, or data.table to compute predictions in.
- na_action
(character) what should happen when new_data contains missing values (i.e., NA values). Valid options are:
'fail': an error is thrown if new_data contains NA values.
'omit': rows in new_data with incomplete data will be dropped.
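To make the difference between expand_grid = TRUE and expand_grid = FALSE concrete, the base R sketch below builds the points at which partial dependence would be evaluated for a hypothetical two-variable pred_spec. It only illustrates the layout of the inputs (the variable names and values are made up) and does not call aorsf.

```r
# Hypothetical pred_spec with two variables
pred_spec <- list(wt = c(2.5, 3.5), hp = c(100, 150, 200))

# expand_grid = TRUE: one point for every combination of values
grid_all <- expand.grid(pred_spec, stringsAsFactors = FALSE)
nrow(grid_all)   # 6 (2 x 3)

# expand_grid = FALSE: each variable is varied on its own
grid_sep <- do.call(rbind, lapply(names(pred_spec), function(nm) {
  data.frame(variable = nm, value = pred_spec[[nm]])
}))
nrow(grid_sep)   # 5 (2 + 3)
```

The long variable/value layout of grid_sep mirrors the variable and value columns that appear in the expand_grid = FALSE output in the Examples.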
Value
a data.table containing partial dependence values for the specified variable(s) and, if relevant, at the specified prediction horizon(s).
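Each row of the returned table summarizes the predictions made at one set of pred_spec inputs: the mean column is their average, and the remaining columns are quantiles named by prob_labels. The sketch below reproduces that kind of summary in base R for a made-up vector of predictions, using the default prob_values and prob_labels. It assumes stats::quantile() with its default type, which may not match the exact quantile definition aorsf uses.

```r
# Made-up predictions for 8 observations at one set of pred_spec inputs
preds <- c(0.12, 0.30, 0.45, 0.51, 0.62, 0.70, 0.83, 0.95)

prob_values <- c(0.025, 0.5, 0.975)
prob_labels <- c("lwr", "medn", "upr")

# Mean plus labelled quantiles, mirroring the mean/lwr/medn/upr columns
summary_row <- c(mean = mean(preds),
                 setNames(quantile(preds, probs = prob_values, names = FALSE),
                          prob_labels))
summary_row
```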
Details
Partial dependence has a number of known limitations and assumptions that users should be aware of (see Hooker, 2021). In particular, partial dependence is less intuitive when more than two predictors are examined jointly, and it assumes that the feature(s) for which partial dependence is computed are not correlated with other features (this is likely not true in many cases). Accumulated local effect plots can be used in the case where feature independence is not a valid assumption.
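Because partial dependence forces every observation to take the requested value, values far from the bulk of the training data amount to extrapolation. The boundary_checks argument guards against this by comparing requested values with the 10th and 90th percentiles of the training data. The sketch below shows such a check in base R. check_boundaries() is a hypothetical helper, mtcars$wt stands in for a training predictor, and raising a warning (rather than an error) is an assumption about how a violation is reported.

```r
# Hypothetical helper: flag requested values outside the 10th-90th percentiles
check_boundaries <- function(x_train, requested) {
  bounds <- quantile(x_train, probs = c(0.10, 0.90), names = FALSE)
  outside <- requested < bounds[1] | requested > bounds[2]
  if (any(outside)) {
    # Assumption: report with a warning; aorsf may handle this differently
    warning("requested values outside the 10th-90th percentile range: ",
            paste(requested[outside], collapse = ", "))
  }
  !outside   # TRUE for values that pass the check
}

check_boundaries(mtcars$wt, c(1.0, 3.0, 6.0))
```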
Examples
You can compute partial dependence and individual conditional expectations in three ways:
using in-bag predictions for the training data. In-bag partial dependence indicates relationships that the model has learned during training. This is helpful if your goal is to interpret the model.
using out-of-bag predictions for the training data. Out-of-bag partial dependence indicates relationships that the model has learned during training, but using the out-of-bag data simulates application of the model to new data. This is helpful if you want to test your model’s reliability or fairness in new data but you don’t have access to a large testing set.
using predictions for a new set of data. New data partial dependence shows how the model predicts outcomes for observations it has not seen. This is helpful if you want to test your model’s reliability or fairness.
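The distinction between in-bag and out-of-bag data comes from how each tree is grown. Assuming standard bootstrap sampling (rows drawn with replacement), the base R sketch below draws the in-bag rows for a single tree and recovers the rows it never saw. Out-of-bag predictions for a row use only the trees for which that row was out-of-bag. This illustrates the idea and is not aorsf's sampling code.

```r
set.seed(329)
n <- 10

# In-bag rows for one tree: a bootstrap sample of the row indices
in_bag <- sample(n, size = n, replace = TRUE)

# Out-of-bag rows: those never drawn, so this tree has not seen them
out_of_bag <- setdiff(seq_len(n), in_bag)

sort(unique(in_bag))
out_of_bag
```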
Classification
Begin by fitting an oblique classification random forest:
library(aorsf)
set.seed(329)
index_train <- sample(nrow(penguins_orsf), 150)
penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]
fit_clsf <- orsf(data = penguins_orsf_train,
formula = species ~ .)
Compute partial dependence using out-of-bag data for flipper_length_mm = c(190, 210).
pred_spec <- list(flipper_length_mm = c(190, 210))
pd_oob <- orsf_pd_oob(fit_clsf, pred_spec = pred_spec)
pd_oob
## Key: <class>
## class flipper_length_mm mean lwr medn upr
## <fctr> <num> <num> <num> <num> <num>
## 1: Adelie 190 0.6180632 0.207463688 0.76047056 0.9809703
## 2: Adelie 210 0.4346177 0.018583256 0.56486883 0.8647387
## 3: Chinstrap 190 0.2119948 0.017692341 0.15658268 0.7163635
## 4: Chinstrap 210 0.1801186 0.020454479 0.09525310 0.7085293
## 5: Gentoo 190 0.1699420 0.001277844 0.02831331 0.5738689
## 6: Gentoo 210 0.3852637 0.068685035 0.20853993 0.9537020
Note that predicted probabilities are returned for each class, and the probabilities in the mean column sum to 1 if you take the sum over each class at a specific value of the pred_spec variables. For example,
sum(pd_oob[flipper_length_mm == 190, mean])
But this isn’t the case for the median predicted probability!
sum(pd_oob[flipper_length_mm == 190, medn])
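Why the mean sums to 1 but the median does not can be seen with a small made-up example in base R. Each row below holds one observation's predicted class probabilities, so each row sums to 1. Averaging is linear, so the column means also sum to 1. Medians are not linear, so their sum can be far from 1.

```r
# Made-up class probabilities for three observations (each row sums to 1)
probs <- rbind(c(Adelie = 0.8, Chinstrap = 0.1, Gentoo = 0.1),
               c(Adelie = 0.1, Chinstrap = 0.8, Gentoo = 0.1),
               c(Adelie = 0.1, Chinstrap = 0.1, Gentoo = 0.8))

sum(colMeans(probs))            # 1
sum(apply(probs, 2, median))    # 0.3
```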
Regression
Begin by fitting an oblique regression random forest:
set.seed(329)
index_train <- sample(nrow(penguins_orsf), 150)
penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]
fit_regr <- orsf(data = penguins_orsf_train,
formula = bill_length_mm ~ .)
Compute partial dependence using new data for flipper_length_mm = c(190, 210).
pred_spec <- list(flipper_length_mm = c(190, 210))
pd_new <- orsf_pd_new(fit_regr,
pred_spec = pred_spec,
new_data = penguins_orsf_test)
pd_new
## flipper_length_mm mean lwr medn upr
## <num> <num> <num> <num> <num>
## 1: 190 42.96571 37.09805 43.69769 48.72301
## 2: 210 45.66012 40.50693 46.31577 51.65163
You can also let pred_spec_auto pick reasonable values like so:
pred_spec <- pred_spec_auto(species, island, body_mass_g)
pd_new <- orsf_pd_new(fit_regr,
pred_spec = pred_spec,
new_data = penguins_orsf_test)
pd_new
## species island body_mass_g mean lwr medn upr
## <fctr> <fctr> <num> <num> <num> <num> <num>
## 1: Adelie Biscoe 3200 40.31374 37.24373 40.31967 44.22824
## 2: Chinstrap Biscoe 3200 45.10582 42.63342 45.10859 47.60119
## 3: Gentoo Biscoe 3200 42.81649 40.19221 42.55664 46.84035
## 4: Adelie Dream 3200 40.16219 36.95895 40.34633 43.90681
## 5: Chinstrap Dream 3200 46.21778 43.53954 45.90929 49.19173
## 6: Gentoo Dream 3200 42.60465 39.89647 42.63520 46.28769
## 7: Adelie Torgersen 3200 39.91652 36.80227 39.79806 43.68842
## 8: Chinstrap Torgersen 3200 44.27807 41.95470 44.40742 46.68848
## 9: Gentoo Torgersen 3200 42.09510 39.49863 41.80049 45.81833
## 10: Adelie Biscoe 3550 40.77971 38.04027 40.59561 44.57505
## 11: Chinstrap Biscoe 3550 45.81304 43.52102 45.73116 48.36366
## 12: Gentoo Biscoe 3550 43.31233 40.77355 43.03077 47.22936
## 13: Adelie Dream 3550 40.77741 38.07399 40.78175 44.37273
## 14: Chinstrap Dream 3550 47.30926 44.80493 46.77540 50.47092
## 15: Gentoo Dream 3550 43.26955 40.86119 43.16204 46.89190
## 16: Adelie Torgersen 3550 40.25780 37.35251 40.07871 44.04576
## 17: Chinstrap Torgersen 3550 44.77911 42.60161 44.81944 47.14986
## 18: Gentoo Torgersen 3550 42.49520 39.95866 42.14160 46.26237
## 19: Adelie Biscoe 3975 41.61744 38.94515 41.36634 45.38752
## 20: Chinstrap Biscoe 3975 46.59363 44.59970 46.44923 49.11457
## 21: Gentoo Biscoe 3975 44.07857 41.60792 43.74562 47.85109
## 22: Adelie Dream 3975 41.50511 39.06187 41.24741 45.13027
## 23: Chinstrap Dream 3975 48.14978 45.87390 47.54867 51.50683
## 24: Gentoo Dream 3975 44.01928 41.70577 43.84099 47.50470
## 25: Adelie Torgersen 3975 40.94764 38.12519 40.66759 44.73689
## 26: Chinstrap Torgersen 3975 45.44820 43.49986 45.44036 47.63243
## 27: Gentoo Torgersen 3975 43.13791 40.70628 42.70627 46.87306
## 28: Adelie Biscoe 4700 42.93914 40.48463 42.44768 46.81756
## 29: Chinstrap Biscoe 4700 47.18534 45.40866 47.07739 49.55747
## 30: Gentoo Biscoe 4700 45.32541 43.08173 44.93498 49.23391
## 31: Adelie Dream 4700 42.73806 40.44229 42.22226 46.49936
## 32: Chinstrap Dream 4700 48.37354 46.34335 48.00781 51.18955
## 33: Gentoo Dream 4700 45.09132 42.88328 44.79530 48.82180
## 34: Adelie Torgersen 4700 42.09349 39.72074 41.56168 45.68838
## 35: Chinstrap Torgersen 4700 46.17045 44.39042 46.09525 48.35127
## 36: Gentoo Torgersen 4700 44.31621 42.18968 43.81773 47.98024
## 37: Adelie Biscoe 5300 43.89769 41.43335 43.28504 48.10892
## 38: Chinstrap Biscoe 5300 47.53721 45.66038 47.52770 49.88701
## 39: Gentoo Biscoe 5300 46.16115 43.81722 45.59309 50.57469
## 40: Adelie Dream 5300 43.59846 41.25825 43.24518 47.46193
## 41: Chinstrap Dream 5300 48.48139 46.36282 48.25679 51.02996
## 42: Gentoo Dream 5300 45.91819 43.62832 45.54110 49.91622
## 43: Adelie Torgersen 5300 42.92879 40.66576 42.31072 46.76406
## 44: Chinstrap Torgersen 5300 46.59576 44.80400 46.49196 49.03906
## 45: Gentoo Torgersen 5300 45.11384 42.95190 44.51289 49.27629
## species island body_mass_g mean lwr medn upr
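The body_mass_g values above appear to be drawn from the distribution of the training data, and every level of each factor is used. The helper below sketches that kind of selection in base R under an assumption: numeric variables get their 10th, 25th, 50th, 75th and 90th percentiles, and factors get all of their levels. auto_values() is hypothetical, and the exact rules pred_spec_auto() follows may differ (see pred_spec_auto).

```r
# Hypothetical helper; the percentile choice is an assumption
auto_values <- function(x, probs = c(0.10, 0.25, 0.50, 0.75, 0.90)) {
  if (is.factor(x)) return(levels(x))                # categorical: every level
  unname(quantile(x, probs = probs, na.rm = TRUE))   # numeric: percentiles
}

auto_values(iris$Sepal.Length)
auto_values(iris$Species)
```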
By default, all combinations of all variables are used. However, you can also look at each variable separately, like so:
pd_new <- orsf_pd_new(fit_regr,
expand_grid = FALSE,
pred_spec = pred_spec,
new_data = penguins_orsf_test)
pd_new
## variable value level mean lwr medn upr
## <char> <num> <char> <num> <num> <num> <num>
## 1: species NA Adelie 41.90271 37.10417 41.51723 48.51478
## 2: species NA Chinstrap 47.11314 42.40419 46.96478 51.51392
## 3: species NA Gentoo 44.37038 39.87306 43.89889 51.21635
## 4: island NA Biscoe 44.21332 37.22711 45.27862 51.21635
## 5: island NA Dream 44.43354 37.01471 45.57261 51.51392
## 6: island NA Torgersen 43.29539 37.01513 44.26924 49.84391
## 7: body_mass_g 3200 <NA> 42.84625 37.03978 43.95991 49.19173
## 8: body_mass_g 3550 <NA> 43.53326 37.56730 44.43756 50.47092
## 9: body_mass_g 3975 <NA> 44.30431 38.31567 45.22089 51.50683
## 10: body_mass_g 4700 <NA> 45.22559 39.88199 46.34680 51.18955
## 11: body_mass_g 5300 <NA> 45.91412 40.84742 46.95327 51.48851
And you can also bypass all the bells and whistles by using your own data.frame for a pred_spec. (Just make sure you request values that exist in the training data.)
custom_pred_spec <- data.frame(species = 'Adelie',
island = 'Biscoe')
pd_new <- orsf_pd_new(fit_regr,
pred_spec = custom_pred_spec,
new_data = penguins_orsf_test)
pd_new
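With a data.frame pred_spec, each row is one point, and all of the variables in that row are set together before predictions are averaged. The base R sketch below mimics this with a hypothetical custom_spec and a linear model on mtcars standing in for the forest. It illustrates the row-wise computation and is not aorsf's implementation.

```r
# Stand-in model (assumption): a linear model instead of an oblique forest
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# Hypothetical custom pred_spec: each row is one point
custom_spec <- data.frame(wt = c(2.5, 3.5), hp = c(110, 110))

pd_rows <- vapply(seq_len(nrow(custom_spec)), function(i) {
  data_mod <- mtcars
  # Set every variable in row i at once; other predictors keep observed values
  for (nm in names(custom_spec)) data_mod[[nm]] <- custom_spec[[nm]][i]
  mean(predict(fit, newdata = data_mod))
}, numeric(1))

cbind(custom_spec, mean = pd_rows)
```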
Survival
Begin by fitting an oblique survival random forest:
set.seed(329)
index_train <- sample(nrow(pbc_orsf), 150)
pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]
fit_surv <- orsf(data = pbc_orsf_train,
formula = Surv(time, status) ~ . - id,
oobag_pred_horizon = 365.25 * 5)
Compute partial dependence using in-bag data for bili = c(1, 2, 3, 4, 5):
pd_train <- orsf_pd_inb(fit_surv, pred_spec = list(bili = 1:5))
pd_train
## pred_horizon bili mean lwr medn upr
## <num> <num> <num> <num> <num> <num>
## 1: 1826.25 1 0.2566200 0.02234786 0.1334170 0.8918909
## 2: 1826.25 2 0.3121392 0.06853733 0.1896849 0.9204338
## 3: 1826.25 3 0.3703242 0.11409793 0.2578505 0.9416791
## 4: 1826.25 4 0.4240692 0.15645214 0.3331057 0.9591581
## 5: 1826.25 5 0.4663670 0.20123406 0.3841700 0.9655296
If you don’t have specific values of a variable in mind, let pred_spec_auto pick for you:
pd_train <- orsf_pd_inb(fit_surv, pred_spec_auto(bili))
pd_train
## pred_horizon bili mean lwr medn upr
## <num> <num> <num> <num> <num> <num>
## 1: 1826.25 0.590 0.2484695 0.02035041 0.1243120 0.8823385
## 2: 1826.25 0.725 0.2508045 0.02060111 0.1274237 0.8836536
## 3: 1826.25 1.500 0.2797763 0.03964900 0.1601715 0.9041584
## 4: 1826.25 3.500 0.3959349 0.13431288 0.2920400 0.9501230
## 5: 1826.25 7.210 0.5344511 0.27869513 0.4651185 0.9782084
Specify pred_horizon to get partial dependence at each value:
pd_train <- orsf_pd_inb(fit_surv, pred_spec_auto(bili),
pred_horizon = seq(500, 3000, by = 500))
pd_train
## pred_horizon bili mean lwr medn upr
## <num> <num> <num> <num> <num> <num>
## 1: 500 0.590 0.06184375 0.0004433990 0.008765301 0.5918852
## 2: 1000 0.590 0.14210619 0.0057937418 0.056124198 0.7381107
## 3: 1500 0.590 0.20859307 0.0136094784 0.091808079 0.8577223
## 4: 2000 0.590 0.26823465 0.0230476894 0.145707217 0.8918696
## 5: 2500 0.590 0.31809404 0.0631155452 0.202189830 0.9035026
## 6: 3000 0.590 0.39152139 0.0911566314 0.302738552 0.9239861
## 7: 500 0.725 0.06255088 0.0004462367 0.008934806 0.5980510
## 8: 1000 0.725 0.14337233 0.0063321712 0.056348007 0.7447805
## 9: 1500 0.725 0.21058059 0.0140736894 0.093113771 0.8597396
## 10: 2000 0.725 0.27056356 0.0235448705 0.146307939 0.8941464
## 11: 2500 0.725 0.31922691 0.0626303822 0.202462648 0.9073970
## 12: 3000 0.725 0.39426313 0.0911457406 0.308440546 0.9252028
## 13: 500 1.500 0.06679162 0.0012717884 0.011028398 0.6241228
## 14: 1000 1.500 0.15727919 0.0114789623 0.068332010 0.7678732
## 15: 1500 1.500 0.23316655 0.0287320952 0.117289745 0.8789647
## 16: 2000 1.500 0.30139227 0.0467927208 0.180096425 0.9144202
## 17: 2500 1.500 0.35260943 0.0845866747 0.238015966 0.9266065
## 18: 3000 1.500 0.43512074 0.1311103304 0.346025144 0.9438562
## 19: 500 3.500 0.08638646 0.0052087533 0.028239001 0.6740930
## 20: 1000 3.500 0.22353655 0.0519179775 0.139604845 0.8283986
## 21: 1500 3.500 0.32700976 0.0901983241 0.217982772 0.9371150
## 22: 2000 3.500 0.41618105 0.1445328597 0.311508093 0.9566091
## 23: 2500 3.500 0.49248461 0.2195110942 0.402095677 0.9636221
## 24: 3000 3.500 0.56008108 0.2635698957 0.503253258 0.9734948
## 25: 500 7.210 0.12550962 0.0220920570 0.063425987 0.7526581
## 26: 1000 7.210 0.32567558 0.1353851175 0.259047345 0.8875150
## 27: 1500 7.210 0.46327019 0.2181840827 0.386681920 0.9700903
## 28: 2000 7.210 0.55042753 0.2912654769 0.483477295 0.9812223
## 29: 2500 7.210 0.61937483 0.3709845684 0.567895754 0.9844945
## 30: 3000 7.210 0.67963922 0.4247511750 0.645083041 0.9888637
## pred_horizon bili mean lwr medn upr
Vector-valued pred_horizon input comes with minimal extra computational cost. Use a fine grid of time values to assess whether predictors have time-varying effects (see the partial dependence vignette for an example).
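One way to see why extra horizons can be cheap: once each observation's survival curve has been predicted at the event times, risk at any horizon is a lookup (risk = 1 - surv, as in the pred_type definitions). The base R sketch below uses made-up survival curves for three observations and a hypothetical risk_at() helper to average risk at several horizons in one call. It illustrates the idea and is not aorsf's internal code.

```r
# Made-up survival curves: rows = observations, columns = event times
event_times <- c(400, 900, 1500, 2400)
surv <- rbind(c(0.95, 0.85, 0.70, 0.55),
              c(0.90, 0.75, 0.60, 0.40),
              c(0.98, 0.93, 0.88, 0.80))

# Hypothetical helper: mean risk (1 - survival) at each horizon
risk_at <- function(surv, event_times, horizons) {
  # Last event time at or before each horizon (0 = before the first event)
  idx <- findInterval(horizons, event_times)
  vapply(idx, function(i) {
    s <- if (i == 0) rep(1, nrow(surv)) else surv[, i]
    mean(1 - s)
  }, numeric(1))
}

risk_at(surv, event_times, horizons = c(500, 1000, 2000))
```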