MondrianForests

Mondrian random forests in Julia

Introduction

This repository provides implementations of Mondrian random forests in Julia, based on methods detailed in Cattaneo, Klusowski and Underwood, 2023, arXiv:2310:09702. This package provides:

  • Fitting Mondrian random forests
  • Fitting debiased Mondrian random forests
  • Selecting a lifetime parameter with polynomial estimation or generalized cross-validation

Branches

The main branch contains stable versions. Other branches may be unstable, and are for development purposes only.

License

This repository and its included Julia package are licensed under GPLv3.

Julia package

The Julia package is named MondrianForests.jl

Installation

From the Julia General registry:

using Pkg
Pkg.add("MondrianForests")

Usage

using MondrianForests

# sample a two-dimensional Mondrian tree
d = 2
lambda = 2.0
tree = MondrianTree(d, lambda)
println()
show(tree)
println()

# generate some data
# covariates X_data are two-dimensional
# response Y_data is one-dimensional
# true regression function is zero
n_data = 100
data = MondrianForests.generate_uniform_data_uniform_errors(d, n_data)
X_data = data["X"]
Y_data = data["Y"]
println("covariates: ")
display(X_data[1:5])
println("\nresponses: ")
display(Y_data[1:5])

# select a lifetime parameter
# with generalized cross-validation
n_trees = 50
n_subsample = 30
debias_order = 0
lambdas = collect(range(0.5, 10.0, step=0.5))
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, debias_order, n_subsample)
println("\nlambda chosen by GCV: ", lambda)

# fit and evaluate a Mondrian random forest
x_evals = [(0.5, 0.5), (0.2, 0.8)]
estimate_var = true
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var)
println("\nestimated regression function:")
display(forest.mu_hat)
println("\nestimated estimator variance:")
display(forest.Sigma_hat)
println("\nestimated confidence band:")
display(forest.confidence_band)

# fit and evaluate a debiased Mondrian random forest
debiased_forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order,
                                         X_data, Y_data, estimate_var)
println("\ndebiased estimated regression function:")
display(debiased_forest.mu_hat)
println("\ndebiased estimated estimator variance:")
display(debiased_forest.Sigma_hat)
println("\ndebiased estimated confidence band:")
display(debiased_forest.confidence_band)

Dependencies

  • Distributions
  • Random
  • Suppressor
  • Test

Documentation

Documentation for the MondrianForests package is available on the web.