StableRulesClassifier

A model type for constructing a stable rules classifier, based on SIRUS.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

StableRulesClassifier = @load StableRulesClassifier pkg=SIRUS

Do model = StableRulesClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableRulesClassifier(rng=...).

StableRulesClassifier implements an explainable rule-based model whose rules are extracted from a random forest.

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)

Train the machine with fit!(mach, rows=...).
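
The steps above can be sketched end to end as follows. This is a minimal illustration rather than a recommended setup: the synthetic data from make_blobs, the seed values, and the row range are assumptions chosen only to make the example self-contained, and it presumes that MLJ, SIRUS, and StableRNGs are installed in the active environment.

```julia
using MLJ
using StableRNGs: StableRNG

# Load the model type (requires SIRUS.jl in the active environment).
StableRulesClassifier = @load StableRulesClassifier pkg=SIRUS verbosity=0

# Synthetic two-class data (an assumption for illustration):
# X is a table of Continuous columns, y is a categorical vector.
X, y = make_blobs(200, 3; centers=2, rng=StableRNG(1))

# Check that the scitypes match the requirements listed above.
println(schema(X))
println(scitype(y))

# Override a few hyperparameter defaults via keyword arguments.
model = StableRulesClassifier(; rng=StableRNG(1), max_depth=2, max_rules=10)

# Bind the model to the data and train on the first 150 rows only.
mach = machine(model, X, y)
fit!(mach; rows=1:150, verbosity=0)
```

Passing rows restricts training to a subset, which leaves the remaining rows available for a later hold-out check.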

Hyperparameters

  • rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
  • partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
  • n_trees::Int=1000: The number of trees to use. At least a thousand trees are advisable for better rule selection and, in turn, better predictive performance.
  • max_depth::Int=2: The maximum depth of each tree. A lower depth decreases model complexity and can therefore reduce overfitting and improve accuracy when the sample size is small.
  • q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
  • min_data_in_leaf::Int=5: Minimum number of data points per leaf.
  • max_rules::Int=10: This is the most important hyperparameter after lambda. The more rules, the more accurate the model should be. If this is not the case, tune lambda first. However, more rules will also decrease model interpretability. So, it is important to find a good balance here. In most cases, 10 to 40 rules should provide reasonable accuracy while remaining interpretable.
  • lambda::Float64=1.0: The weights of the final rules are determined via a regularized regression in which each rule is a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. SIRUS is very sensitive to the choice of this hyperparameter. Ensure that you try the full range from 10^-4 to 10^4 (e.g., 0.0001, 0.001, ..., 10000). When trying the range, one good check is to verify that an increase in max_rules increases performance. If this is not the case, then try a different value for lambda.
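
The advice on lambda and max_rules can be sketched as a small sweep. This is only one possible way to run the check: the helper cv_auc, the synthetic data, the 5-fold cross-validation, the choice of auc as the measure, and the two max_rules values are all assumptions made for illustration, and n_trees is lowered to 100 purely to keep the example fast (the default of 1000 remains the advised setting).

```julia
using MLJ
using StableRNGs: StableRNG

StableRulesClassifier = @load StableRulesClassifier pkg=SIRUS verbosity=0

# Synthetic, somewhat overlapping two-class data (an assumption for illustration).
X, y = make_blobs(200, 3; centers=2, cluster_std=2.0, rng=StableRNG(1))

# Hypothetical helper: cross-validated AUC for one (lambda, max_rules) pair.
function cv_auc(lambda, max_rules)
    # n_trees=100 is used here only for speed.
    model = StableRulesClassifier(; rng=StableRNG(1), n_trees=100, lambda, max_rules)
    resampling = CV(; nfolds=5, shuffle=true, rng=StableRNG(1))
    e = evaluate(model, X, y; resampling, measure=auc, verbosity=0)
    return only(e.measurement)
end

# Full range from 10^-4 to 10^4 in powers of ten.
lambdas = [10.0^k for k in -4:4]

for lambda in lambdas
    few = cv_auc(lambda, 10)
    many = cv_auc(lambda, 25)
    # For a well-chosen lambda, more rules should not hurt the score.
    println("lambda=", lambda,
            "  auc(max_rules=10)=", round(few; digits=3),
            "  auc(max_rules=25)=", round(many; digits=3))
end
```

A lambda for which the score with 25 rules falls below the score with 10 rules is a candidate to discard, following the check described above.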

Fitted parameters

The fields of fitted_params(mach) are:

  • fitresult: A StableRules object.

Operations

  • predict(mach, Xnew): Return a vector of predictions, one for each row of Xnew.
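
Inspecting the fitted rules and predicting on new rows might look like the sketch below. The data, seeds, and the train/new split are assumptions for illustration. The call to predict_mode further assumes that the model is probabilistic in the MLJ sense, so that predict returns distributions and predict_mode returns class labels; if that assumption does not hold for your version, use predict alone.

```julia
using MLJ
using StableRNGs: StableRNG

StableRulesClassifier = @load StableRulesClassifier pkg=SIRUS verbosity=0

# Synthetic two-class data (an assumption for illustration).
X, y = make_blobs(200, 3; centers=2, rng=StableRNG(1))

model = StableRulesClassifier(; rng=StableRNG(1))
mach = machine(model, X, y)
fit!(mach; rows=1:150, verbosity=0)

# The fitted StableRules object; displaying it shows the learned rules.
rules = fitted_params(mach).fitresult
println(rules)

# Treat the rows not used for training as new data.
Xnew = selectrows(X, 151:200)

# One prediction per row of Xnew.
yhat = predict(mach, Xnew)

# Assumption: the model is probabilistic, so predict_mode yields class labels.
labels = predict_mode(mach, Xnew)
println(first(labels, 5))
```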