StableForestRegressor

A model type for constructing a stable forest regressor, based on SIRUS.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

StableForestRegressor = @load StableForestRegressor pkg=SIRUS

Do model = StableForestRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestRegressor(rng=...).

StableForestRegressor implements the random forest regressor with a stabilized forest structure (Bénard et al., 2021).

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

where

  • X: any table of input features (eg, a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y)

Train the machine with fit!(mach, rows=...).
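
For concreteness, here is a minimal end-to-end training sketch on synthetic data; the feature names x1 and x2, the sample size, and the target are illustrative, not part of the package API.

using MLJ
using StableRNGs: StableRNG

StableForestRegressor = @load StableForestRegressor pkg=SIRUS

rng = StableRNG(1)
X = (x1 = rand(rng, 100), x2 = rand(rng, 100))   # a NamedTuple is a valid table of Continuous columns
y = 2 .* X.x1 .+ 0.1 .* randn(rng, 100)          # Continuous target

model = StableForestRegressor(rng=rng)
mach = machine(model, X, y)
fit!(mach)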

Hyperparameters

  • rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
  • partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
  • n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
  • max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reducing overfitting).
  • q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
  • min_data_in_leaf::Int=5: Minimum number of data points per leaf.
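
As a hedged illustration of overriding these defaults, the sketch below passes each hyperparameter explicitly; the specific values shown are arbitrary, not recommendations beyond the defaults above.

using StableRNGs: StableRNG

model = StableForestRegressor(
    rng=StableRNG(1),        # reproducible runs, as advised above
    partial_sampling=0.7,
    n_trees=1_000,
    max_depth=2,
    q=10,
    min_data_in_leaf=5,
)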

Fitted parameters

The fields of fitted_params(mach) are:

  • fitresult: A StableForest object.
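
For example, assuming mach is a trained machine as above, the fitted forest can be retrieved with:

forest = fitted_params(mach).fitresult   # a StableForest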

Operations

  • predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
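
A short usage sketch, assuming mach was trained as above; Xnew must have the same column names and scitypes as the training table (x1 and x2 here are illustrative).

Xnew = (x1 = rand(5), x2 = rand(5))
yhat = predict(mach, Xnew)   # vector of predictions, one per row of Xnew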