StableForestClassifier
A model type for constructing a stable forest classifier, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableForestClassifier = @load StableForestClassifier pkg=SIRUS
Do model = StableForestClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestClassifier(rng=...).
StableForestClassifier implements the random forest classifier with a stabilized forest structure (Bénard et al., 2021). The stabilization makes the rules extracted from the forest more stable. The impact on predictive accuracy compared to standard random forests should be relatively small.
Just like normal random forests, this model is not easily explainable. If you are interested in an explainable model, use the StableRulesClassifier or StableRulesRegressor.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)
Train the machine with fit!(mach, rows=...).
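The steps above can be sketched as follows. This is a minimal illustration, assuming SIRUS.jl and StableRNGs.jl are installed in the active environment; the synthetic data from MLJ's make_blobs, the seed values, and the train row range are assumptions made for the example and are not part of the original documentation.

```julia
using MLJ
using StableRNGs: StableRNG

# Load the model type (requires SIRUS.jl in the active environment).
StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Synthetic two-class data, used here only for illustration.
X, y = make_blobs(200, 3; centers=2, rng=StableRNG(1))

# Check that the scitypes are ones the model accepts.
schema(X)    # column scitypes should be Continuous here
scitype(y)   # element scitype should be a Multiclass type

model = StableForestClassifier(rng=StableRNG(1))
mach = machine(model, X, y)

# Train on a subset of rows; the other rows remain available for evaluation.
train = 1:150
fit!(mach, rows=train, verbosity=0)
```

If a column reports an unexpected scitype, MLJ's coerce function can usually convert it before the machine is built.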
Hyperparameters
rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection, and in turn better predictive performance.
max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (by reducing overfitting).
q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
min_data_in_leaf::Int=5: Minimum number of data points per leaf.
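As a hedged illustration of overriding these defaults, the sketch below constructs a model with explicit keyword arguments. The particular values (for example n_trees=1500 and the seed 1) are arbitrary choices for the example, not recommendations from the original documentation, and SIRUS.jl and StableRNGs.jl are assumed to be installed.

```julia
using MLJ
using StableRNGs: StableRNG

StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Override selected hyperparameters; unspecified ones keep their defaults.
model = StableForestClassifier(;
    rng=StableRNG(1),       # seeded RNG so repeated runs give the same forest
    n_trees=1500,           # example value, above the advised minimum of 1000
    max_depth=2,
    q=10,
    min_data_in_leaf=5,
)

# Hyperparameters are stored as fields on the model instance.
model.n_trees
model.partial_sampling      # not overridden above, so this is the default
```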
Fitted parameters
The fields of fitted_params(mach) are:
fitresult: A StableForest object.
Operations
predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
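A minimal end-to-end sketch of predict and fitted_params follows. The synthetic data, the 75/25 split, and the seeds are assumptions made for illustration. Whether each entry of the result is a class label or a distribution over classes depends on the model's prediction type, so the example only relies on there being one prediction per row.

```julia
using MLJ
using StableRNGs: StableRNG

StableForestClassifier = @load StableForestClassifier pkg=SIRUS verbosity=0

# Synthetic two-class data and a shuffled train/test split of the row indices.
X, y = make_blobs(200, 3; centers=2, rng=StableRNG(1))
train, test = partition(eachindex(y), 0.75; rng=StableRNG(2))

mach = machine(StableForestClassifier(rng=StableRNG(1)), X, y)
fit!(mach, rows=train, verbosity=0)

# Predict on rows that were held out from training.
Xnew = selectrows(X, test)
yhat = predict(mach, Xnew)   # one prediction per row of Xnew

# Access the fitted forest.
forest = fitted_params(mach).fitresult
```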