Tuning a model
In MLJ, tuning is implemented as a model wrapper. After wrapping a model in a tuning strategy (e.g. grid search or random search) and binding the wrapped model to data in a machine, fitting the machine initiates a search for optimal model hyperparameters.
Let's use a decision tree classifier and tune the maximum depth of the tree. As usual, start by loading the data and the model:
using MLJ
using PrettyPrinting
X, y = @load_iris
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
import MLJDecisionTreeInterface ✔
To specify a range of values, you can use the range function:
dtc = DecisionTreeClassifier()
r = range(dtc, :max_depth, lower=1, upper=5)
NumericRange(1 ≤ max_depth ≤ 5; origin=3.0, unit=2.0)
As you can see, the range function takes a model (dtc), a symbol for the hyperparameter of interest (:max_depth) and an indication of how to sample values. For hyperparameters of type <:Real, you should specify a range of values as done above. For hyperparameters of other types (e.g. Symbol), you should use the values=... keyword instead.
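To make the two kinds of range concrete, here is a small sketch. It uses MLJ's iterator function to list the points a numeric range would produce. It uses feature_importance, a Symbol-valued hyperparameter of DecisionTreeClassifier, as the non-Real example. The candidate values [:impurity, :split] are an assumption about what the DecisionTree interface accepts.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc = DecisionTreeClassifier()

# Real-valued (here Int) hyperparameter: give lower and upper bounds
r = range(dtc, :max_depth, lower=1, upper=5)
depths = iterator(r, 5)    # 5 evenly spaced points, rounded to integers

# Non-Real (here Symbol) hyperparameter: list the candidates with values=...
# (assumption: :impurity and :split are both accepted by this model)
r_imp = range(dtc, :feature_importance, values=[:impurity, :split])
candidates = collect(r_imp.values)
```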
Once a range of values has been defined, you can then wrap the model in a TunedModel, specifying the tuning strategy:
tm = TunedModel(model=dtc, ranges=[r, ], measure=cross_entropy)
ProbabilisticTunedModel(
model = DecisionTreeClassifier(
max_depth = -1,
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0,
n_subfeatures = 0,
post_prune = false,
merge_purity_threshold = 1.0,
display_depth = 5,
feature_importance = :impurity,
rng = Random._GLOBAL_RNG()),
tuning = RandomSearch(
bounded = Distributions.Uniform,
positive_unbounded = Distributions.Gamma,
other = Distributions.Normal,
rng = Random._GLOBAL_RNG()),
resampling = Holdout(
fraction_train = 0.7,
shuffle = false,
rng = Random._GLOBAL_RNG()),
measure = LogLoss(tol = 2.22045e-16),
weights = nothing,
class_weights = nothing,
operation = nothing,
range = MLJBase.NumericRange{Int64, MLJBase.Bounded, Symbol}[NumericRange(1 ≤ max_depth ≤ 5; origin=3.0, unit=2.0)],
selection_heuristic = MLJTuning.NaiveSelection(nothing),
train_best = true,
repeats = 1,
n = nothing,
acceleration = ComputationalResources.CPU1{Nothing}(nothing),
acceleration_resampling = ComputationalResources.CPU1{Nothing}(nothing),
check_measure = true,
cache = true)
Note that "wrapping a model in a tuning strategy" as above means creating a new "self-tuning" version of the model, tuned_model = TunedModel(model=...), in which further keyword arguments specify:
- the algorithm (a.k.a. tuning strategy) for searching the hyperparameter space of the model (e.g., tuning = RandomSearch(rng=123) or tuning = Grid(goal=100));
- the resampling strategy, used to evaluate performance for each value of the hyperparameters (e.g., resampling = CV(nfolds=9, rng=123));
- the measure (or measures) on which to base performance evaluations (and for reporting purposes) (e.g., measure = rms or measures = [rms, mae]);
- the range, usually describing the "space" of hyperparameters to be searched (but more generally whatever extra information is required to complete the search specification, e.g., initial values in gradient-descent optimization).
For more options, see ?TunedModel in the Julia REPL.
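As a sketch, here are those keyword arguments spelled out explicitly for the decision tree. The particular choices of Grid(resolution=5) and 5-fold CV with rng=123 are illustrative assumptions, not recommendations.

```julia
using MLJ

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc = DecisionTreeClassifier()
r = range(dtc, :max_depth, lower=1, upper=5)

tm_explicit = TunedModel(model=dtc,
                         tuning=Grid(resolution=5),         # exhaustive search over 5 depths
                         resampling=CV(nfolds=5, rng=123),  # seeded 5-fold cross-validation
                         ranges=r,                          # the space to search
                         measure=cross_entropy)             # probabilistic loss to minimise
```

Nothing is searched at this point. The search only runs once the wrapped model is bound to data in a machine and fitted.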
To fit a tuned model, you can use the usual syntax:
m = machine(tm, X, y)
fit!(m)
trained Machine; does not cache data
model: ProbabilisticTunedModel(model = DecisionTreeClassifier(max_depth = -1, …), …)
1: Source @632 ⏎ ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}
2: Source @125 ⏎ AbstractVector{ScientificTypesBase.Multiclass{3}}
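Conceptually, fitting the tuned machine evaluates each candidate model by resampling and keeps the best one. The following is a simplified, hypothetical hand-rolled version of that loop using MLJ's evaluate. It tries every max_depth in 1:5 rather than sampling randomly. It also shuffles the holdout split, which is an assumption on our part, since the default resampling shown above uses shuffle = false.

```julia
using MLJ

X, y = @load_iris
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0

# Hypothetical manual search: score each depth with a 70/30 holdout split
scores = Dict{Int,Float64}()
for depth in 1:5
    model = DecisionTreeClassifier(max_depth=depth)
    e = evaluate(model, X, y;
                 resampling=Holdout(fraction_train=0.7, shuffle=true, rng=123),
                 measure=cross_entropy,
                 verbosity=0)
    scores[depth] = e.measurement[1]   # one entry per measure; we passed one
end

# findmin on a Dict returns (smallest value, its key)
best_loss, best_depth = findmin(scores)
```

TunedModel adds what this loop lacks: pluggable search strategies, multiple measures, a history of evaluations and, with train_best = true, retraining the winner on all the data.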
In order to inspect the best model, you can call fitted_params on the machine and inspect its best_model field.
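Here is a minimal sketch of that inspection. It repeats the tuning above so that the snippet runs on its own. Passing verbosity=0 to fit! only silences logging.

```julia
using MLJ

X, y = @load_iris
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc = DecisionTreeClassifier()
r = range(dtc, :max_depth, lower=1, upper=5)
tm = TunedModel(model=dtc, ranges=[r, ], measure=cross_entropy)
m = machine(tm, X, y)
fit!(m, verbosity=0)

best = fitted_params(m).best_model   # a DecisionTreeClassifier with the selected hyperparameters
best.max_depth                       # the selected maximum depth
```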
Note that here we have tuned a probabilistic model and consequently used a probabilistic measure for the tuning. We could instead have decided that we only cared about the mode and the misclassification rate. To do this, just use operation=predict_mode in the tuned model:
tm = TunedModel(model=dtc, ranges=r, operation=predict_mode,
measure=misclassification_rate)
m = machine(tm, X, y)
fit!(m)
Let's check the misclassification rate for the best model:
r = report(m)
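The rate itself sits inside the report. This sketch assumes the report field is called best_history_entry, as in recent MLJTuning versions. Calling keys on the report shows which fields your installed version actually provides.

```julia
using MLJ

X, y = @load_iris
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc = DecisionTreeClassifier()
r = range(dtc, :max_depth, lower=1, upper=5)
tm = TunedModel(model=dtc, ranges=r, operation=predict_mode,
                measure=misclassification_rate)
m = machine(tm, X, y)
fit!(m, verbosity=0)

rep = report(m)
keys(rep)                                           # field names available in this version
best_rate = rep.best_history_entry.measurement[1]   # assumed field name
```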
Does anyone want plots? Of course:
using Plots
plot(m, size=(800,600))
Let's generate some simple dummy regression data:
X = (x1=rand(100), x2=rand(100), x3=rand(100))
y = 2X.x1 - X.x2 + 0.05 * randn(100);
Let's then build a simple ensemble model with decision tree regressors:
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
forest = EnsembleModel(model=DecisionTreeRegressor())
import MLJDecisionTreeInterface ✔
DeterministicEnsembleModel(
model = DecisionTreeRegressor(
max_depth = -1,
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0,
n_subfeatures = 0,
post_prune = false,
merge_purity_threshold = 1.0,
feature_importance = :impurity,
rng = Random._GLOBAL_RNG()),
atomic_weights = Float64[],
bagging_fraction = 0.8,
rng = Random._GLOBAL_RNG(),
n = 100,
acceleration = ComputationalResources.CPU1{Nothing}(nothing),
out_of_bag_measure = Any[])
Such a model has nested hyperparameters, in that the ensemble has hyperparameters (e.g. :bagging_fraction) and the atom has hyperparameters (e.g. :n_subfeatures or :max_depth). You can see this by inspecting the parameters using params:
params(forest) |> pprint
(model = (max_depth = -1,
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0,
n_subfeatures = 0,
post_prune = false,
merge_purity_threshold = 1.0,
feature_importance = :impurity,
rng = Random._GLOBAL_RNG()),
atomic_weights = [],
bagging_fraction = 0.8,
rng = Random._GLOBAL_RNG(),
n = 100,
acceleration = ComputationalResources.CPU1{Nothing}(nothing),
out_of_bag_measure = [])
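Because the hyperparameters are nested, they can be read and set with ordinary dot syntax. Here is a short sketch. The new values 2 and 0.6 are arbitrary.

```julia
using MLJ

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
forest = EnsembleModel(model=DecisionTreeRegressor())

forest.bagging_fraction        # ensemble-level hyperparameter
forest.model.n_subfeatures     # atom-level hyperparameter, via the `model` field

# Both levels can be mutated in place
forest.bagging_fraction = 0.6
forest.model.n_subfeatures = 2
```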
Ranges for nested hyperparameters are specified using dot syntax; the rest is done in much the same way as before:
r1 = range(forest, :(model.n_subfeatures), lower=1, upper=3)
r2 = range(forest, :bagging_fraction, lower=0.4, upper=1.0)
tm = TunedModel(model=forest, tuning=Grid(resolution=12),
resampling=CV(nfolds=6), ranges=[r1, r2],
measure=rms)  # assumed: root-mean-square error for this regression task
m = machine(tm, X, y)
fit!(m)
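You can list the points each range contributes with iterator to see what Grid(resolution=12) actually searches. This sketch assumes that integer ranges are rounded and de-duplicated. Under that assumption, n_subfeatures can supply only 3 distinct values, so the grid holds 3 × 12 = 36 candidate models rather than 144.

```julia
using MLJ

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
forest = EnsembleModel(model=DecisionTreeRegressor())
r1 = range(forest, :(model.n_subfeatures), lower=1, upper=3)
r2 = range(forest, :bagging_fraction, lower=0.4, upper=1.0)

v1 = iterator(r1, 12)   # integer range: rounded and de-duplicated
v2 = iterator(r2, 12)   # 12 evenly spaced Float64 values
n_candidates = length(v1) * length(v2)
```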
A useful function to inspect a model after fitting it is the report function, which collects information on the model and the tuning. For instance, you can use it to recover the best measurement:
r = report(m)
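Here is a self-contained sketch of pulling the best measurement and the winning nested hyperparameters out of the tuned ensemble. As before, measure=rms and the report field name best_history_entry are assumptions.

```julia
using MLJ

X = (x1=rand(100), x2=rand(100), x3=rand(100))
y = 2X.x1 - X.x2 + 0.05 * randn(100)

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
forest = EnsembleModel(model=DecisionTreeRegressor())
r1 = range(forest, :(model.n_subfeatures), lower=1, upper=3)
r2 = range(forest, :bagging_fraction, lower=0.4, upper=1.0)
tm = TunedModel(model=forest, tuning=Grid(resolution=12),
                resampling=CV(nfolds=6), ranges=[r1, r2],
                measure=rms)   # assumed measure for this regression task
m = machine(tm, X, y)
fit!(m, verbosity=0)

best_rms = report(m).best_history_entry.measurement[1]   # assumed field name
best = fitted_params(m).best_model
(best.model.n_subfeatures, best.bagging_fraction)        # winning nested values
```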
Let's visualise this:
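Here is a sketch of the visualisation, using the same plot recipe for tuned machines as earlier. It requires Plots to be installed. The setup repeats the nested tuning so that the snippet runs on its own, and measure=rms is again an assumed choice.

```julia
using MLJ, Plots

X = (x1=rand(100), x2=rand(100), x3=rand(100))
y = 2X.x1 - X.x2 + 0.05 * randn(100)

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
forest = EnsembleModel(model=DecisionTreeRegressor())
r1 = range(forest, :(model.n_subfeatures), lower=1, upper=3)
r2 = range(forest, :bagging_fraction, lower=0.4, upper=1.0)
tm = TunedModel(model=forest, tuning=Grid(resolution=12),
                resampling=CV(nfolds=6), ranges=[r1, r2],
                measure=rms)   # assumed measure
m = machine(tm, X, y)
fit!(m, verbosity=0)

# Plots the measurement against both tuned hyperparameters
p = plot(m, size=(800, 600))
```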