XBART: Accelerated Bayesian Additive Regression Trees
This post is based on He, J., Yalov, S., & Hahn, P. R. (2019). XBART: Accelerated Bayesian Additive Regression Trees. Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics, 1130–1138. https://proceedings.mlr.press/v89/he19a.html and He, J., & Hahn, P. R. (2023). Stochastic Tree Ensembles for Regularized Nonlinear Regression. Journal of the American Statistical Association, 118(541), 551–570. https://doi.org/10.1080/01621459.2021.1942012
- BART is not merely a Bayesian version of random forests (RF): the Bayesian perspective leads to a fundamentally different tree-growing criterion and algorithm.
The BART model is an additive-error mean regression model
\[y_i = f(x_i) + \epsilon_i\]
The BART prior represents the unknown function $f(x)$ as a sum of many piecewise-constant binary regression trees
\[f(x) = \sum_{l=1}^L g_l(x, T_l, \mu_l)\]
where
- $T_l$: a regression tree
- $\mu_l$: a vector of scalar means associated with the leaf nodes of $T_l$
The tree prior $p(T_l)$ is specified by three components (a minimal prior-draw sketch follows the list):
- the probability that a node at depth $d$ has children: $\alpha(1+d)^{-\beta}$, with $\alpha\in (0, 1)$ and $\beta\in [0, +\infty)$
- the uniform distribution over available predictors for splitting rule assignment at each interior node
- the uniform distribution on the discrete set of available splitting values for the assigned predictor at each interior node
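As a concrete illustration, here is a minimal sketch of drawing a tree skeleton from this prior (my own illustrative code, not from the papers; $\alpha = 0.95$, $\beta = 2$ are the usual BART defaults):

```python
import numpy as np

def draw_tree(X, depth=0, alpha=0.95, beta=2.0, rng=None):
    """Draw a tree skeleton from the BART prior.

    A node at depth d is given children with probability alpha * (1 + d)**(-beta);
    the split variable and split value are drawn uniformly at random.
    Returns a nested dict for an interior node or None for a leaf
    (whose mean mu would be drawn separately from its own prior).
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() >= alpha * (1.0 + depth) ** (-beta):
        return None
    j = rng.integers(X.shape[1])             # uniform over predictors
    cut = rng.choice(np.unique(X[:, j]))     # uniform over available split values
    # A full implementation would restrict X to each child's cell before recursing.
    return {"var": j, "cut": cut,
            "left": draw_tree(X, depth + 1, alpha, beta, rng),
            "right": draw_tree(X, depth + 1, alpha, beta, rng)}
```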
BART splitting criterion
For the data in a given leaf, integrating out the leaf mean, the prior predictive distribution is simply a mean-zero multivariate normal distribution with covariance matrix
\[V = \tau JJ^\top + \sigma^2 I\]
where $J$ is a column vector of ones.
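Concretely, for a leaf containing $n$ observations with sum $s = \sum_i y_i$, this mean-zero multivariate normal density reduces, by the matrix determinant lemma and the Sherman-Morrison formula (a standard conjugate-normal calculation, written out here for reference), to
\[\log m(y) = -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2}\log\!\left(\frac{\sigma^2 + n\tau}{\sigma^2}\right) - \frac{\lVert y\rVert^2}{2\sigma^2} + \frac{\tau s^2}{2\sigma^2(\sigma^2 + n\tau)}\]
so the integrated likelihood of a cut-point depends on the data only through the counts and sums of the two partitions it creates (the $\lVert y\rVert^2$ terms are common to all cut-points at a given node).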
The BART MCMC
The sequence of Gibbs updates is (a structural sketch of this backfitting loop follows the list):
- $T_l, \mu_l \mid r_l, \sigma^2$, where $r_l$ is the partial residual of $y$ after subtracting the fits of all trees except tree $l$:
  a. $T_l \mid r_l, \sigma^2$
  b. $\mu_l \mid T_l, r_l, \sigma^2$
- $\sigma^2 \mid r$
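A structural sketch of one such backfitting iteration (the helpers `predict_tree`, `sample_tree`, `sample_leaf_means`, and `sample_sigma2` are placeholders standing in for the conditional updates above, not a real implementation):

```python
import numpy as np

def bart_gibbs_iteration(y, X, trees, mus, sigma2, prior):
    """One Bayesian backfitting iteration (structural sketch only; the helpers
    predict_tree, sample_tree, sample_leaf_means and sample_sigma2 are
    placeholders for the conditional updates listed above)."""
    # Fits of every tree at the current parameter values, one column per tree.
    fits = np.column_stack([predict_tree(T, mu, X) for T, mu in zip(trees, mus)])
    for l in range(len(trees)):
        # Partial residual r_l: data minus the fit of all *other* trees.
        r_l = y - (fits.sum(axis=1) - fits[:, l])
        # Step 1(a): update the tree structure given r_l and sigma^2
        # (Metropolis-Hastings moves in BART; grow-from-root in XBART).
        trees[l] = sample_tree(trees[l], r_l, X, sigma2, prior)
        # Step 1(b): draw the leaf means from their conjugate normal posteriors.
        mus[l] = sample_leaf_means(trees[l], r_l, sigma2, prior)
        fits[:, l] = predict_tree(trees[l], mus[l], X)
    # Step 2: update the error variance given the full residual.
    sigma2 = sample_sigma2(y - fits.sum(axis=1), prior)
    return trees, mus, sigma2
```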
Step 1(a) is handled with a random-walk Metropolis-Hastings step: given a current tree $T$, modifications are proposed and either accepted or rejected according to a likelihood ratio.
Chipman et al. (1998) describe proposals comprising a birth/death pair (growing a new split at a terminal node or pruning a terminal split).
XBART
Grow-from-root backfitting
Rather than making small moves to a given tree at iteration $k+1$, XBART ignores the current tree and grows an entirely new tree from scratch:
- consider the no-split option, corresponding to a cut-point outside of the range of the available data
- with $C$ available active cut-points and $V$ total variables, perform $C\times V + 1$ likelihood evaluations (see the sketch below)
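A minimal sketch of one grow-from-root step at a single node under the Gaussian leaf model (function names, the `no_split_weight` argument, and the flat enumeration over cut-points are my own simplifications; the papers weight the no-split option by a prior-derived factor and compute the sums far more efficiently):

```python
import numpy as np

def log_marginal(y, sigma2, tau):
    """log m(y) for one leaf: y | mu ~ N(mu, sigma2 I), mu ~ N(0, tau), mu integrated out."""
    n, s = len(y), y.sum()
    return (-0.5 * n * np.log(2 * np.pi * sigma2)
            - 0.5 * np.log((sigma2 + n * tau) / sigma2)
            - (y @ y) / (2 * sigma2)
            + tau * s**2 / (2 * sigma2 * (sigma2 + n * tau)))

def sample_split(y, X, cutpoints, sigma2, tau, no_split_weight=1.0, rng=None):
    """Score every (variable, cutpoint) pair plus the no-split option and
    sample one in proportion to its integrated likelihood (C*V + 1 evaluations)."""
    rng = np.random.default_rng() if rng is None else rng
    options = [None]  # None encodes the no-split option
    loglik = [np.log(no_split_weight) + log_marginal(y, sigma2, tau)]
    for j, cuts in enumerate(cutpoints):        # cutpoints[j]: candidates for variable j
        for c in cuts:
            left = X[:, j] <= c
            if left.all() or not left.any():
                continue                        # cut-point outside this node's data range
            options.append((j, c))
            loglik.append(log_marginal(y[left], sigma2, tau)
                          + log_marginal(y[~left], sigma2, tau))
    w = np.exp(np.asarray(loglik) - max(loglik))
    return options[rng.choice(len(options), p=w / w.sum())]
```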
Pre-sorting Features for Efficiency
The BART criterion depends only on the partition sums (and counts) of the data in the two candidate child nodes.
An important implication for computation is that, with pre-sorted predictor variables, the various cut-point integrated likelihoods can be computed rapidly via a single sweep through the data (per variable).
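For example (a sketch under the Gaussian model above; `order_j = np.argsort(X[:, j])` would be computed once and reused, and ties between equal predictor values would need extra care in a real implementation):

```python
import numpy as np

def sweep_partition_sums(y, order_j):
    """One pass over observations pre-sorted by variable j.

    After the cumulative sum, splitting just after sorted position k gives a
    left partition with count k + 1 and sum csum[k]; the right partition has
    count n - (k + 1) and sum csum[-1] - csum[k]. These are exactly the
    sufficient statistics the integrated likelihood needs at every cut-point.
    """
    y_sorted = y[order_j]
    csum = np.cumsum(y_sorted)
    n_left = np.arange(1, len(y) + 1)
    return n_left, csum, len(y) - n_left, csum[-1] - csum
```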
Recursively Defined Cut-points
Consider a restricted number of cut-points $C$ per variable, taking every $j$-th value (order statistic) of the predictor within the current node as an eligible split point.
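A sketch of the stride rule (my own naming; because candidates are recomputed within each node as the recursion proceeds, small or deep nodes effectively end up with all of their distinct values as candidates):

```python
import numpy as np

def candidate_cutpoints(x_j, max_cutpoints):
    """Keep roughly `max_cutpoints` candidates by taking every stride-th
    distinct value of variable j within the current node."""
    values = np.sort(np.unique(x_j))
    stride = max(1, len(values) // max_cutpoints)
    return values[::stride]
```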
Sparse Proposal Distribution
================
Stochastic Tree Ensembles for Regularized Nonlinear Regression (He & Hahn, 2023)
- a novel stochastic tree ensemble method for nonlinear regression
- combine regularization and stochastic search strategies from Bayesian modeling with computationally efficient techniques from recursive partitioning algorithms
- XBART provides accurate point-wise estimates of the mean function and does so faster than popular alternatives
- using XBART to initialize the standard BART MCMC algorithm considerably improves credible interval coverage and reduces total run-time
2. A Recursive, Stochastic Fitting Algorithm
2.1. Fitting a single tree recursively and stochastically
A tree $T$ is a set of split rules defining a rectangular partition of the covariate space into $\{\mathcal{A}_1,\ldots, \mathcal{A}_B\}$, where $B$ is the total number of terminal nodes of tree $T$.
Each rectangular cell $\mathcal{A}_b$ is associated with a leaf parameter $\mu_b$, and the pair $(T, \mu)$ parameterizes a step function $g(\cdot)$ on the covariate space.
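A minimal evaluation of $g(x; T, \mu)$, with $T$ encoded as nested dicts whose leaves hold an index into $\mu$ (my own encoding, for illustration only):

```python
def g(x, tree, mu):
    """Route x down the split rules of T and return the leaf parameter
    mu_b of the rectangular cell A_b that contains x."""
    node = tree
    while isinstance(node, dict):
        node = node["left"] if x[node["var"]] <= node["cut"] else node["right"]
    return mu[node]  # node is now an integer leaf index b
```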
XBART is a modification of Algorithm 1 (recursive partitioning of a single tree) in which the partition cutpoints and the stopping condition are determined stochastically.
2.1.1 The XBART Marginal Likelihood Split Criterion
- $c_{jk}$: a candidate cutpoint; $j=1,\ldots, p$ indexes the column (variable) of $X$ and $k$ indexes the set of candidate cutpoints (rows) of $X$
- $\vert \mathcal{C}\vert$: the total number of cutpoint candidates
- $\Phi$: prior hyper-parameters
- $\Psi$: model parameters
Consider a likelihood $\ell(y_b; \mu_b, \Psi_b)$ on one leaf, with
- $y_b$: the vector of data observations within the leaf
- $\mu_b$: the leaf-specific parameter
- $\Psi_b$: additional model parameters
The leaf parameter $\mu$ is given a prior $\pi(\mu;\Phi)$, which induces a prior predictive distribution
\[m(y;\Phi, \Psi) = \int \ell(y;\mu, \Psi)\pi(\mu;\Phi)\, d\mu\]
A cutpoint $c_{jk}$ partitions the current node into left and right child nodes, with data (sub)vectors $y_{jk}^{(1)}$ and $y_{jk}^{(2)}$, respectively.
Assuming that observations in separate leaf nodes are independent, the joint prior predictive associated with this local Bayesian model is simply the product of the prior predictive distributions in the two partitions defined by $c_{jk}$:
\[L(c_{jk}) = m(y_{jk}^{(1)}; \Phi, \Psi)\cdot m(y_{jk}^{(2)}; \Phi, \Psi)\]
which defines the split criterion for cutpoint $c_{jk}$.
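In code, with any leaf-level prior predictive available in log form (for instance the Gaussian closed form sketched earlier) and assuming NumPy arrays, the criterion is just a sum of two terms; a hedged sketch:

```python
def log_split_criterion(y, x_j, c, log_m):
    """log L(c_jk) = log m(y^(1)) + log m(y^(2)) for cutpoint c on variable j;
    `log_m` is any leaf-level log prior predictive m(y; Phi, Psi)."""
    left = x_j <= c
    return log_m(y[left]) + log_m(y[~left])
```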
The criterion for the null cutpoint, i.e., not splitting the node at all, is denoted
\[L(\emptyset)\]
2.2. Tree Ensembles
- the algorithm produces $I$ samples of the forest $\mathcal{F}$
- a sweep: one iteration of the algorithm, sampling all $L$ trees once
Regression with XBART
\[Y = f(X) + \epsilon\]
where $f$ is an unknown mean function that is represented as a sum of regression trees, and
- $\epsilon\sim N(0, \sigma^2)$, and $\sigma^2$ is given an inverse-Gamma($a_\sigma, b_\sigma$) prior
- $\sigma^2$ is updated between tree updates
- leaf parameters are given independent and identical Gaussian priors, $\mu \sim N(0, \tau)$, and $\tau$ is given an inverse-Gamma($a_\tau, b_\tau$) prior (see the update sketch after this list)
- $\tau$ is updated between sweeps
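A sketch of the conjugate inverse-Gamma draws these priors imply, assuming the usual (shape, scale) parameterization and a NumPy `Generator` for `rng` (helper names are mine; the papers interleave these draws with the tree updates on the schedule described above):

```python
import numpy as np

def sample_inverse_gamma(shape, scale, rng):
    """Draw from InvGamma(shape, scale) as the reciprocal of a Gamma draw."""
    return 1.0 / rng.gamma(shape, 1.0 / scale)

def update_sigma2(residual, a_sigma, b_sigma, rng):
    """sigma^2 | r ~ InvGamma(a + n/2, b + sum(r^2)/2) when r_i ~ N(0, sigma^2)."""
    return sample_inverse_gamma(a_sigma + 0.5 * len(residual),
                                b_sigma + 0.5 * np.sum(residual ** 2), rng)

def update_tau(leaf_means, a_tau, b_tau, rng):
    """tau | mu ~ InvGamma(a + m/2, b + sum(mu^2)/2) when mu_b ~ N(0, tau)."""
    return sample_inverse_gamma(a_tau + 0.5 * len(leaf_means),
                                b_tau + 0.5 * np.sum(np.square(leaf_means)), rng)
```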
Prediction
Given $I$ iterations of the algorithm,
- the final $I-I_0$ samples are used to compute a point-wise average function evaluation
- $I_0 < I$ denotes the length of the burn-in period
The authors recommend $I = 40$ and $I_0 = 15$.
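A minimal sketch of this averaging step (`predict_forest` is a placeholder that evaluates one stored forest sample at the prediction points):

```python
import numpy as np

def predict_mean(forest_samples, X_new, I0, predict_forest):
    """Point estimate of f(X_new): average the fitted forests over the
    final I - I0 post-burn-in sweeps."""
    fits = [predict_forest(forest, X_new) for forest in forest_samples[I0:]]
    return np.mean(fits, axis=0)
```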