RecursiveFeatureElimination
RecursiveFeatureElimination(model, n_features, step)
This model implements a recursive feature elimination algorithm for feature selection. It recursively removes features, training a base model on the remaining features and evaluating their importance until the desired number of features is selected.
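The elimination loop described above can be sketched in plain Julia. This is only an illustration, not the package's implementation: toy_importances and rfe_sketch are hypothetical names, and an ordinary least-squares fit stands in for the base model.

```julia
using Statistics, Random

# Stand-in for a base model (an assumption for this sketch): ordinary least
# squares with an intercept. Importance of a column is its absolute
# coefficient scaled by the column's standard deviation.
function toy_importances(A::AbstractMatrix, y::AbstractVector)
    coefs = hcat(ones(size(A, 1)), A) \ y
    return abs.(coefs[2:end]) .* vec(std(A, dims=1))
end

# Refit on the surviving columns and drop the `step` weakest ones per round
# until only `n_select` columns remain.
function rfe_sketch(A::AbstractMatrix, y::AbstractVector; n_select::Int, step::Int=1)
    remaining = collect(1:size(A, 2))            # indices of surviving columns
    while length(remaining) > n_select
        imp = toy_importances(A[:, remaining], y)
        n_drop = min(step, length(remaining) - n_select)
        drop = partialsortperm(imp, 1:n_drop)    # positions of the weakest features
        deleteat!(remaining, sort(drop))
    end
    return remaining
end

rng = MersenneTwister(1)
A = rand(rng, 200, 6)
# The target depends only on columns 1 and 4, plus a little noise.
y = 3 .* A[:, 1] .- 2 .* A[:, 4] .+ 0.01 .* randn(rng, 200)
kept = rfe_sketch(A, y; n_select=2, step=2)
println(kept)
```

Refitting after every round matters because importances can shift once correlated or irrelevant columns are gone.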
Construct an instance with default hyper-parameters using the syntax rfe_model = RecursiveFeatureElimination(model=...). Provide keyword arguments to override hyper-parameter defaults.
Training data
In MLJ or MLJBase, bind an instance rfe_model to data with
mach = machine(rfe_model, X, y)
OR, if the base model supports weights, as
mach = machine(rfe_model, X, y, w)
Here:
- X: any table of input features (e.g., a DataFrame) whose columns have the scitype required by the base model; check column scitypes with schema(X) and the column scitypes required by the base model with input_scitype(basemodel).
- y: the target, which can be any table of responses whose element scitype is Continuous or Finite, depending on the target_scitype required by the base model; check the scitype with scitype(y).
- w: the observation weights, which can be either nothing (default) or an AbstractVector whose element scitype is Count or Continuous. These per-observation weights are distinct from any weights hyper-parameter the base model itself may define.
Train the machine using fit!(mach, rows=...).
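As a hedged illustration of the checks mentioned above, the following sketch compares the scitypes of some made-up data against those required by a base model before binding a machine. The random-forest base model and the data are assumptions chosen for the example, and the MLJDecisionTreeInterface package is assumed to be installed.

```julia
using MLJ

# Assumed base model for this sketch; any model that reports feature
# importances could be substituted.
RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree verbosity=0
base = RandomForestRegressor()

X = MLJ.table(rand(20, 4))   # illustrative inputs
y = rand(20)                 # illustrative continuous target

schema(X)                           # column names and their scitypes
input_scitype(base)                 # scitype the base model requires of X
target_scitype(base)                # scitype the base model requires of y

# Both checks should hold before calling machine(rfe_model, X, y).
inputs_ok = scitype(X) <: input_scitype(base)
target_ok = scitype(y) <: target_scitype(base)

# The base model must also report feature importances to be usable here.
importances_ok = reports_feature_importances(base)
```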
Hyper-parameters
- model: A base model with a fit method that provides information on feature importance (i.e., reports_feature_importances(model) == true).
- n_features::Real = 0: The number of features to select. If 0, half of the features are selected. If a positive integer, it is the absolute number of features to select. If a real number between 0 and 1 (exclusive), it is the fraction of features to select.
- step::Real = 1: If step is at least 1, it is the number of features to eliminate in each iteration. If step lies strictly between 0.0 and 1.0, it is the proportion (rounded down) of features to remove in each iteration.
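The rules for n_features and step can be sketched as two small helpers. The names resolve_n_features and resolve_step are hypothetical, and several details are assumptions rather than documented behaviour: the fraction in n_features is rounded down, a fractional step is taken relative to the total number of features p, and at least one feature is always kept or removed.

```julia
# Hypothetical helper: how many features to keep, given p candidate columns.
function resolve_n_features(n_features::Real, p::Int)
    n_features == 0 && return p ÷ 2          # default: half of the features
    if 0 < n_features < 1
        # Fraction of features; rounding down is an assumption of this sketch.
        return max(1, floor(Int, n_features * p))
    end
    return min(Int(n_features), p)           # absolute count (must be integer-valued)
end

# Hypothetical helper: how many features to drop per iteration.
function resolve_step(step::Real, p::Int)
    step >= 1 && return floor(Int, step)     # absolute number per iteration
    if 0 < step < 1
        # Proportion rounded down; using p as the base is an assumption.
        return max(1, floor(Int, step * p))
    end
    throw(ArgumentError("step must be positive"))
end

resolve_n_features(0, 10)     # half of 10 columns
resolve_n_features(0.25, 8)   # a quarter of 8 columns
resolve_n_features(3, 10)     # exactly 3 columns
resolve_step(0.5, 10)         # half of 10 columns per iteration
```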
Operations
- transform(mach, X): transform the input table X into a new table containing only the columns corresponding to features selected by the RFE algorithm.
- predict(mach, X): transform the input table X as in transform(mach, X) above, then predict using the fitted base model on the transformed table.
Fitted parameters
The fields of fitted_params(mach) are:
- features_left: names of features remaining after recursive feature elimination.
- model_fitresult: fitted parameters of the base model.
Report
The fields of report(mach) are:
- scores: dictionary of scores for each feature in the training dataset. The model deems highly scored variables more significant.
- model_report: report for the fitted base model.
Examples
using FeatureSelection, MLJ, StableRNGs
RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
## Create a dataset where the target depends only on the first 5 columns of the input table.
rng = StableRNG(10)  # any fixed seed works; 10 is arbitrary
A = rand(rng, 50, 10);
y = 10 .* sin.(
    pi .* A[:, 1] .* A[:, 2]
) .+ 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 .* A[:, 5];
X = MLJ.table(A);
## fit an RFE model
rf = RandomForestRegressor()
selector = RecursiveFeatureElimination(model = rf)
mach = machine(selector, X, y)
fit!(mach)
## view the feature importances
feature_importances(mach)
## predict using the base model
Xnew = MLJ.table(rand(rng, 50, 10));
predict(mach, Xnew)
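The operations, fitted parameters and report fields listed earlier can be inspected along the following lines. This is a self-contained sketch under assumptions: the seed, the toy target and n_features = 3 are made up for illustration, and the FeatureSelection, MLJ, StableRNGs and MLJDecisionTreeInterface packages are assumed to be installed.

```julia
using FeatureSelection, MLJ, StableRNGs

RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree verbosity=0

rng = StableRNG(123)                 # arbitrary seed, for reproducible data only
A = rand(rng, 50, 10)
y = 10 .* A[:, 4] .+ 5 .* A[:, 5]    # toy target depending on columns 4 and 5
X = MLJ.table(A)

# Keep exactly three features (an illustrative choice).
selector = RecursiveFeatureElimination(model = RandomForestRegressor(), n_features = 3)
mach = machine(selector, X, y)
fit!(mach, verbosity = 0)

Xsmall = transform(mach, X)                   # table restricted to the selected columns
kept = fitted_params(mach).features_left      # names of the selected features
scores = report(mach).scores                  # per-feature scores
```

Comparing kept with the column names of Xsmall is a quick way to confirm that transform returns exactly the selected features.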