DecisionTreeClassifier

A model type for constructing a CART decision tree classifier, based on DecisionTree.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

Do model = DecisionTreeClassifier() to construct an instance with default hyperparameters. Provide keyword arguments to override hyperparameter defaults, as in DecisionTreeClassifier(max_depth=...).

DecisionTreeClassifier implements the CART algorithm, originally published in Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984): "Classification and regression trees". Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

where

  • X: any table of input features (eg, a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y), as in the sketch following this list
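
For example, using the built-in iris dataset (chosen here purely for illustration), these requirements can be checked as follows:

using MLJ
X, y = @load_iris
schema(X)   ## shows the element scitype of each column (all Continuous here)
scitype(y)  ## AbstractVector{Multiclass{3}}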

Train the machine using fit!(mach, rows=...).
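
For instance (a minimal sketch; the row range is arbitrary):

using MLJ
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
model = DecisionTreeClassifier()
X, y = @load_iris
mach = machine(model, X, y)
fit!(mach, rows=1:100)   ## train on the first 100 observations only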

Hyperparameters

  • max_depth=-1: max depth of the decision tree (-1=any)
  • min_samples_leaf=1: min number of samples each leaf needs to have
  • min_samples_split=2: min number of samples needed for a split
  • min_purity_increase=0: min purity needed for a split
  • n_subfeatures=0: number of features to select at random (0 for all)
  • post_prune=false: set to true for post-fit pruning
  • merge_purity_threshold=1.0: (post-pruning) merge leaves having combined purity >= merge_purity_threshold
  • display_depth=5: max depth to show when displaying the tree
  • feature_importance=:impurity: method to use for computing feature importances. One of (:impurity, :split)
  • rng=Random.GLOBAL_RNG: random number generator or seed
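
As an illustration, and assuming the type has been loaded as above, a tree with a depth limit and post-pruning enabled might be constructed like this (the particular values are arbitrary):

model = DecisionTreeClassifier(
    max_depth=4,
    min_samples_leaf=5,
    post_prune=true,
    merge_purity_threshold=0.9,
    rng=123,                    ## an integer seed is also accepted
)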

Operations

  • predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.
  • predict_mode(mach, Xnew): instead return the mode of each prediction above.
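
As the probabilistic predictions are UnivariateFinite distributions, predict_mode is equivalent to broadcasting mode over them; a minimal sketch:

yhat = predict(mach, Xnew)   ## vector of UnivariateFinite distributions
mode.(yhat)                  ## the same point predictions as predict_mode(mach, Xnew)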

Fitted parameters

The fields of fitted_params(mach) are:

  • raw_tree: the raw Node, Leaf or Root object returned by the core DecisionTree.jl algorithm
  • tree: a visualizable, wrapped version of raw_tree implementing the AbstractTrees.jl interface; see "Examples" below
  • encoding: dictionary of target classes keyed on integers used internally by DecisionTree.jl
  • features: the names of the features encountered in training, in an order consistent with the output of print_tree (see below)
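
A sketch of inspecting these fields, assuming a machine mach trained as in "Examples" below:

fp = fitted_params(mach)
fp.features           ## feature names, as a vector of Symbols
fp.encoding           ## the internal integer => class dictionary
typeof(fp.raw_tree)   ## a DecisionTree.jl Node, Leaf or Root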

Report

The fields of report(mach) are:

  • classes_seen: list of target classes actually observed in training
  • print_tree: alternative method to print the fitted tree, whose single argument sets the display depth; interpretation requires the internal integer-class encoding (see "Fitted parameters" above).
  • features: the names of the features encountered in training, in an order consistent with the output of print_tree (see above)
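
For example, to inspect the report (a sketch; classes in the printed tree appear as the internal integers mentioned above):

r = report(mach)
r.classes_seen     ## target classes observed in training
r.print_tree(3)    ## print the fitted tree to depth 3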

Accessor functions

  • feature_importances(mach) returns a vector of (feature::Symbol => importance) pairs; the type of importance is determined by the hyperparameter feature_importance (see above)

Examples

using MLJ
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)

X, y = @load_iris
mach = machine(model, X, y) |> fit!

Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) ## probabilistic predictions
predict_mode(mach, Xnew)   ## point predictions
pdf.(yhat, "virginica")    ## probabilities for the "virginica" class

julia> tree = fitted_params(mach).tree
petal_length < 2.45
├─ setosa (50/50)
└─ petal_width < 1.75
   ├─ petal_length < 4.95
   │  ├─ versicolor (47/48)
   │  └─ virginica (4/6)
   └─ petal_length < 4.85
      ├─ virginica (2/3)
      └─ virginica (43/43)

using Plots, TreeRecipe
plot(tree) ## for a graphical representation of the tree

feature_importances(mach)

See also DecisionTree.jl and the unwrapped model type MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier.