DecisionTreeClassifier

mutable struct DecisionTreeClassifier <: MLJModelInterface.Probabilistic

A simple Decision Tree model for classification with support for Missing data, from the Beta Machine Learning Toolkit (BetaML).

Hyperparameters:

  • max_depth::Int64: The maximum depth the tree is allowed to reach. When this depth is reached, the node is forced to become a leaf [def: 0, i.e. no limit]
  • min_gain::Float64: The minimum information gain required to partition a node [def: 0]
  • min_records::Int64: The minimum number of records a node must hold to be considered for partitioning [def: 2]
  • max_features::Int64: The maximum number of (randomly selected) features to consider at each partitioning [def: 0, i.e. look at all features]
  • splitting_criterion::Function: The function used to compute the information gain of a specific partition. The gain is the difference between the "impurity" of the labels in the parent node and the impurity of the two child nodes, weighted by their respective number of items. Either gini, entropy or a custom function; anonymous functions are also accepted [def: gini]
  • rng::Random.AbstractRNG: A random number generator used in the stochastic parts of the code [def: Random.GLOBAL_RNG]
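The information-gain computation described under splitting_criterion can be sketched in plain Julia as below. This is a minimal illustration, not BetaML's implementation: the helper names (class_shares, gini_impurity, entropy_impurity, information_gain) are hypothetical, each impurity function is assumed to take a vector of labels, and entropy is assumed to use base-2 logarithms.

```julia
# Hypothetical helpers illustrating impurity-based information gain.
# Not BetaML's internal code; signatures are assumptions.

# Share of each distinct class among `labels` (order of classes is irrelevant here).
function class_shares(labels::AbstractVector)
    n = length(labels)
    counts = Dict{eltype(labels),Int}()
    for l in labels
        counts[l] = get(counts, l, 0) + 1
    end
    return [c / n for c in values(counts)]
end

# Gini impurity: 1 - Σ p_k²  (0.0 for an empty node).
gini_impurity(labels) =
    isempty(labels) ? 0.0 : 1.0 - sum(p^2 for p in class_shares(labels))

# Entropy (assumed base 2): -Σ p_k log2(p_k)  (0.0 for an empty node).
entropy_impurity(labels) =
    isempty(labels) ? 0.0 : -sum(p * log2(p) for p in class_shares(labels))

# Parent impurity minus the size-weighted impurity of the two children.
function information_gain(parent, left, right; impurity = gini_impurity)
    n = length(parent)
    weighted_children = (length(left) / n) * impurity(left) +
                        (length(right) / n) * impurity(right)
    return impurity(parent) - weighted_children
end

parent = ["a", "a", "b", "b"]
left, right = ["a", "a"], ["b", "b"]
information_gain(parent, left, right)                               # 0.5
information_gain(parent, left, right; impurity = entropy_impurity)  # 1.0
```

A perfect split of a balanced two-class node removes all impurity, so the gain equals the parent's impurity (0.5 under Gini, 1.0 bits under base-2 entropy). A split whose gain falls below min_gain would not be performed.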

Example:

julia> using MLJ

julia> X, y        = @load_iris;

julia> modelType   = @load DecisionTreeClassifier pkg = "BetaML" verbosity=0
BetaML.Trees.DecisionTreeClassifier

julia> model       = modelType()
DecisionTreeClassifier(
  max_depth = 0, 
  min_gain = 0.0, 
  min_records = 2, 
  max_features = 0, 
  splitting_criterion = BetaML.Utils.gini, 
  rng = Random._GLOBAL_RNG())

julia> mach        = machine(model, X, y);

julia> fit!(mach);
[ Info: Training machine(DecisionTreeClassifier(max_depth = 0, …), …).

julia> cat_est    = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt32, Float64}:
 UnivariateFinite{Multiclass{3}}(setosa=>1.0, versicolor=>0.0, virginica=>0.0)
 UnivariateFinite{Multiclass{3}}(setosa=>1.0, versicolor=>0.0, virginica=>0.0)
 ⋮
 UnivariateFinite{Multiclass{3}}(setosa=>0.0, versicolor=>0.0, virginica=>1.0)
 UnivariateFinite{Multiclass{3}}(setosa=>0.0, versicolor=>0.0, virginica=>1.0)
 UnivariateFinite{Multiclass{3}}(setosa=>0.0, versicolor=>0.0, virginica=>1.0)