AdaBoostStumpClassifier
A model type for constructing an Ada-boosted stump classifier, based on DecisionTree.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
AdaBoostStumpClassifier = @load AdaBoostStumpClassifier pkg=DecisionTree

Do model = AdaBoostStumpClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in AdaBoostStumpClassifier(n_iter=...).
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)

where:
X: any table of input features (eg, a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)

y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)
Train the machine with fit!(mach, rows=...).
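For example, here is a minimal sketch of checking and enforcing these scitype requirements before binding; the table X and target y below are illustrative only and not tied to any particular dataset:

using MLJ
AdaBoostStumpClassifier = @load AdaBoostStumpClassifier pkg=DecisionTree

X = (height = [1.85, 1.67, 1.50, 1.72],   ## Continuous column
     n_devices = [1, 3, 2, 2])            ## Count column
y = coerce(["yes", "no", "yes", "no"], Multiclass)  ## coerce raw labels to Multiclass

schema(X)    ## inspect the column scitypes
scitype(y)   ## inspect the target scitype

mach = machine(AdaBoostStumpClassifier(), X, y)
fit!(mach)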
Hyperparameters
n_iter=10: number of iterations of AdaBoost

feature_importance: method to use for computing feature importances. One of (:impurity, :split)

rng=Random.GLOBAL_RNG: random number generator or seed
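For example (the values below are illustrative, not recommendations), all three hyper-parameters can be set at construction:

using MLJ
AdaBoostStumpClassifier = @load AdaBoostStumpClassifier pkg=DecisionTree

model = AdaBoostStumpClassifier(
    n_iter = 50,                  ## boost for 50 iterations
    feature_importance = :split,  ## use split-based feature importances
    rng = 123)                    ## integer seed, for reproducibility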
Operations
predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.

predict_mode(mach, Xnew): instead return the mode of each prediction above.
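A short sketch of these operations, re-using the iris data that also appears in the Examples section below:

using MLJ
AdaBoostStumpClassifier = @load AdaBoostStumpClassifier pkg=DecisionTree
X, y = @load_iris
mach = fit!(machine(AdaBoostStumpClassifier(), X, y))

yhat = predict(mach, X)    ## vector of UnivariateFinite distributions
predict_mode(mach, X)      ## point predictions (class labels)
mode.(yhat)                ## should agree with predict_mode above
pdf.(yhat, "setosa")       ## probability assigned to the class "setosa"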
Fitted Parameters
The fields of fitted_params(mach) are:
stumps: the Ensemble object returned by the core DecisionTree.jl algorithm.

coefficients: the stump coefficients (one per stump)
Report
The fields of report(mach) are:
features: the names of the features encountered in training
Accessor functions
feature_importances(mach) returns a vector of (feature::Symbol => importance) pairs; the type of importance is determined by the hyperparameter feature_importance (see above)
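For example, to rank features by importance (a sketch using the iris data; the sorting step is ordinary Julia, not a model-specific API):

using MLJ
AdaBoostStumpClassifier = @load AdaBoostStumpClassifier pkg=DecisionTree
X, y = @load_iris
mach = fit!(machine(AdaBoostStumpClassifier(), X, y))

fi = feature_importances(mach)   ## vector of feature => importance pairs
sort(fi, by=last, rev=true)      ## features ranked, most important first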
Examples
using MLJ
Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
booster = Booster(n_iter=15)
X, y = @load_iris
mach = machine(booster, X, y) |> fit!
Xnew = (sepal_length = [6.4, 7.2, 7.4],
sepal_width = [2.8, 3.0, 2.8],
petal_length = [5.6, 5.8, 6.1],
petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) ## probabilistic predictions
predict_mode(mach, Xnew) ## point predictions
pdf.(yhat, "virginica") ## probabilities for the "virginica" class
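## As an additional illustration (not part of the original example), estimate
## out-of-sample performance by cross-validation; the measures are illustrative:
evaluate(booster, X, y,
         resampling=CV(nfolds=6),
         measures=[log_loss, accuracy])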
fitted_params(mach).stumps ## raw `Ensemble` object from DecisionTree.jl
fitted_params(mach).coefs ## coefficient associated with each stump
feature_importances(mach)

See also DecisionTree.jl and the unwrapped model type MLJDecisionTreeInterface.DecisionTree.AdaBoostStumpClassifier.