DecisionTreeRegressor

A model type for constructing a CART decision tree regressor, based on DecisionTree.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree

Do model = DecisionTreeRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in DecisionTreeRegressor(max_depth=...).

DecisionTreeRegressor implements the CART algorithm, originally published in Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984): "Classification and regression trees". Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y) (a minimal check is sketched after this list)
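For instance, assuming MLJ is loaded, scitypes can be inspected (and columns coerced, if needed) like this; the table X below is made up purely for illustration:

using MLJ
X = (x1 = rand(20), x2 = rand(1:5, 20)) ## Continuous and Count columns
y = rand(20)                            ## Continuous target
schema(X)    ## inspect column scitypes
scitype(y)   ## AbstractVector{Continuous}
X = coerce(X, :x2 => Continuous) ## optional: coerce a Count column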

Train the machine with fit!(mach, rows=...).
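
A minimal sketch, using synthetic data for illustration, that trains on a subset of rows:

using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor()
X, y = make_regression(50, 3; rng=1) ## synthetic data
mach = machine(model, X, y)
fit!(mach, rows=1:40) ## train on the first 40 observations only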

Hyperparameters

  • max_depth=-1: max depth of the decision tree (-1=any)
  • min_samples_leaf=1: min number of samples each leaf needs to have
  • min_samples_split=2: min number of samples needed for a split
  • min_purity_increase=0: minimum purity increase needed for a split
  • n_subfeatures=0: number of features to select at random (0 for all)
  • post_prune=false: set to true for post-fit pruning
  • merge_purity_threshold=1.0: (post-pruning) merge leaves having combined purity >= merge_purity_threshold
  • feature_importance=:impurity: method to use for computing feature importances; one of :impurity or :split
  • rng=Random.GLOBAL_RNG: random number generator or seed
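
Purely as an illustration (not recommended settings), several of these hyper-parameters might be combined like so:

model = DecisionTreeRegressor(
    max_depth = 5,                ## cap tree depth
    min_samples_leaf = 2,         ## each leaf needs at least 2 samples
    post_prune = true,            ## prune after fitting
    merge_purity_threshold = 0.9,
    rng = 123,                    ## seed for reproducibility
)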

Operations

  • predict(mach, Xnew): return predictions of the target given new features Xnew having the same scitype as X above.

Fitted parameters

The fields of fitted_params(mach) are:

  • tree: the tree or stump object returned by the core DecisionTree.jl algorithm
  • features: the names of the features encountered in training
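
For example, given a trained machine mach as above:

fp = fitted_params(mach)
fp.tree     ## raw DecisionTree.jl tree or stump
fp.features ## feature names seen during training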

Report

The fields of report(mach) are:

  • features: the names of the features encountered in training
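
For example:

report(mach).features ## same feature names, via the training report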

Accessor functions

  • feature_importances(mach) returns a vector of (feature::Symbol => importance) pairs; the type of importance is determined by the hyperparameter feature_importance (see above)
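
A minimal sketch switching to split-based importances (assuming X and y as in the training example above):

model = DecisionTreeRegressor(feature_importance=:split)
mach = machine(model, X, y) |> fit!
feature_importances(mach) ## vector of feature => importance pairs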

Examples

using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)

X, y = make_regression(100, 4; rng=123) ## synthetic data
mach = machine(model, X, y) |> fit!

Xnew, _ = make_regression(3, 4; rng=123) ## new data must have the same 4 features
yhat = predict(mach, Xnew) ## new predictions

julia> fitted_params(mach).tree
x1 < 0.2758
├─ x2 < 0.9137
│  ├─ x1 < -0.9582
│  │  ├─ 0.9189256882087312 (0/12)
│  │  └─ -0.23180616021065256 (0/38)
│  └─ -1.6461153800037722 (0/9)
└─ x1 < 1.062
   ├─ x2 < -0.4969
   │  ├─ -0.9330755147107384 (0/5)
   │  └─ -2.3287967825015548 (0/17)
   └─ x2 < 0.4598
      ├─ -2.931299926506291 (0/11)
      └─ -4.726518740473489 (0/8)

feature_importances(mach) ## get feature importances

See also DecisionTree.jl and the unwrapped model type MLJDecisionTreeInterface.DecisionTree.DecisionTreeRegressor.