StableRulesRegressor
A model type for constructing a stable rules regressor, based on SIRUS.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

    StableRulesRegressor = @load StableRulesRegressor pkg=SIRUS

Do `model = StableRulesRegressor()` to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in `StableRulesRegressor(rng=...)`.
`StableRulesRegressor` implements the explainable rule-based regression model based on a random forest.
Training data
In MLJ or MLJBase, bind an instance `model` to data with

    mach = machine(model, X, y)

where

- `X`: any table of input features (e.g., a `DataFrame`) whose columns each have one of the following element scitypes: `Continuous`, `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
- `y`: the target, which can be any `AbstractVector` whose element scitype is `Continuous`; check the scitype with `scitype(y)`

Train the machine with `fit!(mach, rows=...)`.
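The steps above can be sketched end to end as follows. This is a minimal example assuming MLJ, SIRUS, and StableRNGs are installed; the synthetic table and target are placeholders, not part of the SIRUS API:

```julia
# Minimal end-to-end sketch (assumes the MLJ, SIRUS, and StableRNGs packages
# are installed in the active environment).
using MLJ
using StableRNGs: StableRNG

StableRulesRegressor = @load StableRulesRegressor pkg=SIRUS

# Synthetic regression data: a Tables.jl-compatible table X with Continuous
# columns, and a Continuous target vector y.
rng = StableRNG(1)
X = (a = rand(rng, 100), b = rand(rng, 100))
y = 2 .* X.a .+ X.b .+ 0.1 .* randn(rng, 100)

model = StableRulesRegressor(rng=StableRNG(1))
mach = machine(model, X, y)
fit!(mach)
```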
Hyperparameters

- `rng::AbstractRNG=default_rng()`: Random number generator. Using a `StableRNG` from StableRNGs.jl is advised.
- `partial_sampling::Float64=0.7`: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
- `n_trees::Int=1000`: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
- `max_depth::Int=2`: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (reducing overfitting).
- `q::Int=10`: Number of cutpoints to use per feature. The default value should be fine for most situations.
- `min_data_in_leaf::Int=5`: Minimum number of data points per leaf.
- `max_rules::Int=10`: This is the most important hyperparameter after `lambda`. The more rules, the more accurate the model should be. If this is not the case, tune `lambda` first. However, more rules also decrease model interpretability, so it is important to find a good balance here. In most cases, 10 to 40 rules should provide reasonable accuracy while remaining interpretable.
- `lambda::Float64=1.0`: The weights of the final rules are determined via a regularized regression over each rule as a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. SIRUS is very sensitive to the choice of this hyperparameter. Ensure that you try the full range from 10^-4 to 10^4 (e.g., 0.001, 0.01, ..., 100). When trying the range, one good check is to verify that an increase in `max_rules` increases performance. If this is not the case, then try a different value for `lambda`.
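The check described for `lambda` can be sketched as a simple sweep over several orders of magnitude, comparing cross-validated error at each value. This is a sketch assuming MLJ and SIRUS are installed; `rms` and `evaluate` are MLJ's built-in measure and evaluation utilities, and the data here is synthetic:

```julia
using MLJ
using StableRNGs: StableRNG

StableRulesRegressor = @load StableRulesRegressor pkg=SIRUS

# Synthetic placeholder data; substitute your own table and target.
rng = StableRNG(1)
X = (a = rand(rng, 200), b = rand(rng, 200))
y = 3 .* X.a .- X.b .+ 0.1 .* randn(rng, 200)

# Sweep lambda across orders of magnitude and compare cross-validated RMS.
# A lower score is better; repeat with a larger max_rules to verify that
# more rules improve performance at the chosen lambda.
for lambda in (0.001, 0.01, 0.1, 1.0, 10.0, 100.0)
    model = StableRulesRegressor(; lambda, rng=StableRNG(1))
    e = evaluate(model, X, y; resampling=CV(nfolds=3), measure=rms, verbosity=0)
    println("lambda = $lambda  =>  rms = $(only(e.measurement))")
end
```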
Fitted parameters

The fields of `fitted_params(mach)` are:

- `fitresult`: A `StableRules` object.
Operations

- `predict(mach, Xnew)`: Return a vector of predictions for each row of `Xnew`.
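Continuing from a trained machine `mach` as constructed above, a call to `predict` might look like the following sketch; the column names in `Xnew` are assumed to match the training features:

```julia
# Assumes `mach` is a machine that has already been trained with fit!(mach).
# Xnew must be a table with the same column names as the training table X.
Xnew = (a = [0.1, 0.9], b = [0.5, 0.2])
yhat = predict(mach, Xnew)  # a vector with one point prediction per row of Xnew
```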