DecisionTreeRegressor
A model type for constructing a CART decision tree regressor, based on DecisionTree.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
Do model = DecisionTreeRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in DecisionTreeRegressor(max_depth=...).
DecisionTreeRegressor implements the CART algorithm, originally published in Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984): "Classification and regression trees". Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y)
Train the machine with fit!(mach, rows=...).
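The scitype requirements above can be checked, and repaired if necessary, before building a machine. The following is a minimal sketch; the column names and values are made up for illustration, and a NamedTuple of vectors stands in for a DataFrame as the table.

```julia
using MLJ

# A column table (NamedTuple of vectors) is one kind of table MLJ accepts.
# The column names and values here are hypothetical.
X = (age    = [23, 45, 31, 52],           # integers have scitype Count
     height = [1.70, 1.82, 1.65, 1.78])   # floats have scitype Continuous

schema(X)   # lists each column's name, scitype, and machine type

# An integer target has element scitype Count, which this model does not accept.
y_raw = [10, 20, 15, 30]
scitype(y_raw)   # AbstractVector{Count}

# Coerce the target so that its element scitype is Continuous.
y = coerce(y_raw, Continuous)
scitype(y)       # AbstractVector{Continuous}
```

Coercing changes the machine representation (here, integers become floats), so it is worth doing once, up front, rather than inside a training loop.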
Hyperparameters
max_depth=-1: max depth of the decision tree (-1=any)
min_samples_leaf=1: min number of samples each leaf needs to have
min_samples_split=2: min number of samples needed for a split
min_purity_increase=0: min increase in purity needed for a split
n_subfeatures=0: number of features to select at random (0 for all)
post_prune=false: set to true for post-fit pruning
merge_purity_threshold=1.0: (post-pruning) merge leaves having combined purity >= merge_purity_threshold
feature_importance: method to use for computing feature importances. One of (:impurity, :split)
rng=Random.GLOBAL_RNG: random number generator or seed
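As an illustration of how these hyperparameters are set, the sketch below overrides a few defaults at construction time and then changes two more afterwards. The particular values are arbitrary choices for the example, not recommendations.

```julia
using MLJ

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0

# Override a few defaults; every other hyperparameter keeps its default value.
model = DecisionTreeRegressor(
    max_depth = 4,          # limit the tree depth (-1 would mean no limit)
    min_samples_leaf = 5,   # every leaf must hold at least 5 samples
    n_subfeatures = 2,      # select 2 features at random (0 would mean all)
    rng = 42,               # an integer seed, for reproducible feature selection
)

# MLJ models are mutable, so hyperparameters can also be changed after construction.
model.post_prune = true
model.merge_purity_threshold = 0.9
```

Changing a hyperparameter on a model that is already bound to a machine takes effect the next time fit! is called on that machine.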
Operations
predict(mach, Xnew): return predictions of the target given new features Xnew having the same scitype as X above.
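A common pattern is to train on a subset of rows and predict on the rest. The sketch below assumes synthetic data from make_regression and a 70/30 split; the variable name holdout_rmse is introduced only for this example, and the error is computed by hand rather than with a built-in measure.

```julia
using MLJ

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0

X, y = make_regression(100, 4; rng=123)   # synthetic data

# Shuffle the row indices and split them 70/30 into training and test sets.
train, test = partition(eachindex(y), 0.7; rng=123)

mach = machine(DecisionTreeRegressor(max_depth=3), X, y)
fit!(mach, rows=train, verbosity=0)       # train on the training rows only

yhat = predict(mach, rows=test)           # predict on the held-out rows

# Root-mean-squared error on the held-out rows, computed by hand.
holdout_rmse = sqrt(sum(abs2, yhat .- y[test]) / length(test))
```

Passing rows=... to predict reuses the data already bound to the machine, so no separate Xnew table is needed for held-out evaluation.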
Fitted parameters
The fields of fitted_params(mach) are:
tree: the tree or stump object returned by the core DecisionTree.jl algorithm
features: the names of the features encountered in training
Report
The fields of report(mach) are:
features: the names of the features encountered in training
Accessor functions
feature_importances(mach) returns a vector of (feature::Symbol => importance) pairs; the type of importance is determined by the hyperparameter feature_importance (see above).
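The following sketch shows one way to work with the returned pairs, here using split-based importances on synthetic data. The explicit sort is included only to make the ordering unambiguous; the accessor may already return the pairs in a useful order.

```julia
using MLJ

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0

X, y = make_regression(100, 4; rng=123)   # synthetic data with 4 features

model = DecisionTreeRegressor(max_depth=3, feature_importance=:split)
mach = machine(model, X, y)
fit!(mach, verbosity=0)

importances = feature_importances(mach)   # vector of feature => importance pairs

# Rank the features from most to least important.
ranked = sort(importances; by=last, rev=true)
for (feature, importance) in ranked
    println(feature, " => ", importance)
end
```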
Examples
using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)
X, y = make_regression(100, 4; rng=123) # synthetic data
mach = machine(model, X, y) |> fit!
Xnew, _ = make_regression(3, 4; rng=123) # new inputs with the same 4 features as X
yhat = predict(mach, Xnew) # new predictions
julia> fitted_params(mach).tree
x1 < 0.2758
├─ x2 < 0.9137
│  ├─ x1 < -0.9582
│  │  ├─ 0.9189256882087312 (0/12)
│  │  └─ -0.23180616021065256 (0/38)
│  └─ -1.6461153800037722 (0/9)
└─ x1 < 1.062
   ├─ x2 < -0.4969
   │  ├─ -0.9330755147107384 (0/5)
   │  └─ -2.3287967825015548 (0/17)
   └─ x2 < 0.4598
      ├─ -2.931299926506291 (0/11)
      └─ -4.726518740473489 (0/8)
feature_importances(mach) # get feature importances
See also DecisionTree.jl and the unwrapped model type MLJDecisionTreeInterface.DecisionTree.DecisionTreeRegressor.