StableForestClassifier

A model type for constructing a stable forest classifier, based on SIRUS.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

StableForestClassifier = @load StableForestClassifier pkg=SIRUS

Do model = StableForestClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestClassifier(rng=...).

StableForestClassifier implements a random forest classifier with a stabilized forest structure (Bénard et al., 2021). Stabilizing the forest makes rule extraction more stable. The cost in predictive accuracy compared to standard random forests should be relatively small.

Note

Just like normal random forests, this model is not easily explainable. If you are interested in an explainable model, use the StableRulesClassifier or StableRulesRegressor.

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)

Train the machine with fit!(mach, rows=...).
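
The binding and training steps above can be sketched as follows. This is a minimal illustration, not part of the package documentation: the synthetic table, the column names x1 and x2, the labels "a" and "b", the seed, and the reduced n_trees=50 (chosen only to keep the run short) are all assumptions. It also assumes that MLJ, SIRUS, and StableRNGs are installed in the active environment.

```julia
using MLJ
using StableRNGs: StableRNG

# Load the model type as described above.
StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Hypothetical synthetic data: two Continuous features and a two-class target.
rng = StableRNG(1)
n = 100
X = (x1 = rand(rng, n), x2 = rand(rng, n))        # a column table
y = categorical(ifelse.(X.x1 .> 0.5, "a", "b"))   # element scitype Multiclass{2}

# Check that the scitypes match what the model expects.
schema(X)
scitype(y)

# Bind the model to the data and train on the first 80 rows only.
model = StableForestClassifier(; rng=StableRNG(1), n_trees=50)
mach = machine(model, X, y)
fit!(mach; rows=1:80, verbosity=0)
```

Passing rows=... restricts training to a subset of the bound data, which leaves the remaining rows available for evaluation.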

Hyperparameters

  • rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
  • partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
  • n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
  • max_depth::Int=2: The maximum depth of each tree. A lower depth decreases model complexity, which reduces overfitting and can therefore improve accuracy when the sample size is small.
  • q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
  • min_data_in_leaf::Int=5: Minimum number of data points per leaf.
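
As an illustration of overriding these hyperparameters, the sketch below constructs a model with every keyword spelled out. The numeric values simply repeat the documented defaults; the StableRNG seed is an arbitrary assumption. It assumes that MLJ, SIRUS, and StableRNGs are installed.

```julia
using MLJ
using StableRNGs: StableRNG

StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Every keyword below is listed in the hyperparameter section above;
# the values are the documented defaults, except for the seeded rng.
model = StableForestClassifier(;
    rng = StableRNG(1),        # seed chosen arbitrarily for reproducibility
    partial_sampling = 0.7,
    n_trees = 1000,
    max_depth = 2,
    q = 10,
    min_data_in_leaf = 5,
)
```

Hyperparameters are ordinary mutable fields of the model, so they can be inspected afterwards, for example with model.n_trees.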

Fitted parameters

The fields of fitted_params(mach) are:

  • fitresult: A StableForest object.

Operations

  • predict(mach, Xnew): Return a vector of predictions, one for each row of Xnew.
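
A short end-to-end sketch of inspecting the fitted parameters and calling predict is given below. The synthetic data, the column names, the seed, and the reduced n_trees=50 are assumptions made only for this illustration. It assumes that MLJ, SIRUS, and StableRNGs are installed.

```julia
using MLJ
using StableRNGs: StableRNG

StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Hypothetical training data.
rng = StableRNG(1)
X = (x1 = rand(rng, 100), x2 = rand(rng, 100))
y = categorical(ifelse.(X.x1 .> 0.5, "a", "b"))

model = StableForestClassifier(; rng=StableRNG(1), n_trees=50)
mach = machine(model, X, y)
fit!(mach; verbosity=0)

# The fitted forest is available under the `fitresult` field.
forest = fitted_params(mach).fitresult

# Hypothetical new data with the same columns as X.
Xnew = (x1 = rand(rng, 20), x2 = rand(rng, 20))
yhat = predict(mach, Xnew)   # one prediction per row of Xnew
```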