StableForestRegressor
A model type for constructing a stable forest regressor, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableForestRegressor = @load StableForestRegressor pkg=SIRUS
Do model = StableForestRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestRegressor(rng=...).
StableForestRegressor implements the random forest regressor with a stabilized forest structure (Bénard et al., 2021).
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where

- X: any table of input features (eg, a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
- y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y)
Train the machine with fit!(mach, rows=...).
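As a minimal sketch (the synthetic table and target below are illustrative and not taken from the SIRUS documentation), a machine can be built and trained as follows:

    using MLJ
    using DataFrames: DataFrame
    using StableRNGs: StableRNG

    StableForestRegressor = @load StableForestRegressor pkg=SIRUS

    rng = StableRNG(1)
    X = DataFrame(x1=rand(rng, 100), x2=rand(rng, 100))  # Continuous features
    y = 2 .* X.x1 .+ X.x2 .+ 0.1 .* rand(rng, 100)       # Continuous target

    model = StableForestRegressor(; rng=rng)
    mach = machine(model, X, y)
    fit!(mach)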
Hyperparameters
- rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
- partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
- n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
- max_depth::Int=2: The depth of the trees. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reduce overfitting).
- q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
- min_data_in_leaf::Int=5: Minimum number of data points per leaf.
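For example (the values below are illustrative, assuming StableForestRegressor has been loaded via @load as shown above), hyper-parameters can be overridden at construction time:

    using StableRNGs: StableRNG

    model = StableForestRegressor(;
        rng=StableRNG(1),      # reproducible randomness, as advised above
        partial_sampling=0.7,  # fraction of samples drawn for each tree
        n_trees=1000,          # at least a thousand trees for good rule selection
        max_depth=2,           # shallow trees to limit overfitting
        q=10,                  # cutpoints per feature
        min_data_in_leaf=5,    # minimum data points per leaf
    )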
Fitted parameters
The fields of fitted_params(mach) are:

- fitresult: A StableForest object.
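For example, the fitted forest can be retrieved from a trained machine (mach as constructed above):

    forest = fitted_params(mach).fitresult  # the underlying StableForest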
Operations
- predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
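A usage sketch (Xnew is any table with the same column schema as the training features; the table below is illustrative):

    Xnew = DataFrame(x1=rand(5), x2=rand(5))  # hypothetical new observations
    yhat = predict(mach, Xnew)                # vector of point predictions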