DecisionTreeClassifier
A model type for constructing a CART decision tree classifier, based on DecisionTree.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
Do model = DecisionTreeClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in DecisionTreeClassifier(max_depth=...).
DecisionTreeClassifier implements the CART algorithm, originally published in Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984): "Classification and regression trees". Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)
Train the machine using fit!(mach, rows=...).
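As a minimal sketch of the steps above, the following checks the scitypes of the built-in iris data before binding it to a machine, then trains on a subset of rows. The 70/30 split and the seed 123 are arbitrary choices for illustration, not recommendations.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0

X, y = @load_iris   # X is a column table, y a categorical vector

schema(X)           # shows the scitype of each column; all four are Continuous here
scitype(y)          # an AbstractVector scitype with Multiclass elements

# Bind a default model to the data
mach = machine(DecisionTreeClassifier(), X, y)

# Train on a random 70% of the rows (split fraction and seed chosen arbitrarily)
train, test = partition(eachindex(y), 0.7, shuffle=true, rng=123)
fit!(mach, rows=train, verbosity=0)
```

Rows not listed in rows=... are ignored during training, so test can be kept for a later call to predict.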
Hyperparameters
max_depth=-1: max depth of the decision tree (-1=any)
min_samples_leaf=1: min number of samples each leaf needs to have
min_samples_split=2: min number of samples needed for a split
min_purity_increase=0: min purity increase needed for a split
n_subfeatures=0: number of features to select at random (0 for all)
post_prune=false: set to true for post-fit pruning
merge_purity_threshold=1.0: (post-pruning) merge leaves having combined purity >= merge_purity_threshold
display_depth=5: max depth to show when displaying the tree
feature_importance: method to use for computing feature importances. One of (:impurity, :split)
rng=Random.GLOBAL_RNG: random number generator or seed
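The hyperparameters listed above are passed as keyword arguments to the constructor. The values in this sketch are arbitrary and only illustrate the syntax; an integer is passed for rng on the basis that it accepts a seed.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0

# Illustrative, untuned values
model = DecisionTreeClassifier(
    max_depth = 4,                  # limit the tree depth
    min_samples_leaf = 5,           # each leaf must hold at least 5 samples
    n_subfeatures = 2,              # consider 2 randomly chosen features
    post_prune = true,              # prune after fitting ...
    merge_purity_threshold = 0.9,   # ... merging leaves with combined purity >= 0.9
    feature_importance = :split,    # one of (:impurity, :split)
    rng = 42,                       # integer seed for reproducibility
)

# Hyperparameters can also be changed after construction
model.max_depth = 3
```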
Operations
predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.
predict_mode(mach, Xnew): instead return the mode of each prediction above.
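The relationship between the two operations can be sketched as follows. For brevity the model is fit and evaluated on the full iris data, which is not a sound way to assess performance.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
X, y = @load_iris

mach = machine(DecisionTreeClassifier(max_depth=3), X, y)
fit!(mach, verbosity=0)

yhat = predict(mach, X)           # one probability distribution per row
probs = pdf.(yhat, "setosa")      # probability assigned to the "setosa" class
labels = predict_mode(mach, X)    # most probable class per row

# Taking the mode of each probabilistic prediction gives the same labels
labels == mode.(yhat)
```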
Fitted parameters
The fields of fitted_params(mach) are:
raw_tree: the raw Node, Leaf or Root object returned by the core DecisionTree.jl algorithm
tree: a visualizable, wrapped version of raw_tree implementing the AbstractTrees.jl interface; see "Examples" below
encoding: dictionary of target classes keyed on integers used internally by DecisionTree.jl
features: the names of the features encountered in training, in an order consistent with the output of print_tree (see below)
Report
The fields of report(mach) are:
classes_seen: list of target classes actually observed in training
print_tree: alternative method to print the fitted tree, with single argument the tree depth; interpretation requires internal integer-class encoding (see "Fitted parameters" above).
features: the names of the features encountered in training, in an order consistent with the output of print_tree (see below)
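A short sketch of how the fields listed under "Fitted parameters" and "Report" might be inspected after training. The iris data and max_depth=2 are arbitrary choices that keep the printed tree small.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
X, y = @load_iris

mach = machine(DecisionTreeClassifier(max_depth=2), X, y)
fit!(mach, verbosity=0)

fp = fitted_params(mach)
fp.encoding         # internal integer codes mapped to target classes
fp.features         # feature names, in the order used by print_tree

r = report(mach)
r.classes_seen      # target classes observed in training
r.print_tree(2)     # print the tree to depth 2; leaves show integer class codes
```

The integer codes printed by print_tree can be translated back to class labels with fp.encoding.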
Accessor functions
feature_importances(mach) returns a vector of (feature::Symbol => importance) pairs; the type of importance is determined by the hyperparameter feature_importance (see above)
Examples
using MLJ
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
X, y = @load_iris
mach = machine(model, X, y) |> fit!
Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) ## probabilistic predictions
predict_mode(mach, Xnew) ## point predictions
pdf.(yhat, "virginica") ## probabilities for the "virginica" class
julia> tree = fitted_params(mach).tree
petal_length < 2.45
├─ setosa (50/50)
└─ petal_width < 1.75
   ├─ petal_length < 4.95
   │  ├─ versicolor (47/48)
   │  └─ virginica (4/6)
   └─ petal_length < 4.85
      ├─ virginica (2/3)
      └─ virginica (43/43)
using Plots, TreeRecipe
plot(tree) ## for a graphical representation of the tree
feature_importances(mach)
See also DecisionTree.jl and the unwrapped model type MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier.