# IteratedModel

```julia
IteratedModel(model;
    controls=MLJIteration.DEFAULT_CONTROLS,
    resampling=Holdout(),
    measure=nothing,
    retrain=false,
    advanced_options...,
)
```
Wrap the specified supervised `model` in the specified iteration `controls`. Here `model` should support iteration, which is true if `iteration_parameter(model)` is different from `nothing`.
Available controls: `Step()`, `Info()`, `Warn()`, `Error()`, `Callback()`, `WithLossDo()`, `WithTrainingLossesDo()`, `WithNumberDo()`, `Data()`, `Disjunction()`, `GL()`, `InvalidValue()`, `Never()`, `NotANumber()`, `NumberLimit()`, `NumberSinceBest()`, `PQ()`, `Patience()`, `Threshold()`, `TimeLimit()`, `Warmup()`, `WithIterationsDo()`, `WithEvaluationDo()`, `WithFittedParamsDo()`, `WithReportDo()`, `WithMachineDo()`, `WithModelDo()`, `CycleLearningRate()` and `Save()`.
To make out-of-sample losses available to the controls, the wrapped model is only trained on part of the data, as iteration proceeds. The user may want to force retraining on all data after controlled iteration has finished by specifying `retrain=true`. See also "Training", and the `retrain` option, under "Extended help" below.
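For example, here is a minimal usage sketch. The `EvoTreeClassifier` model, loaded from the third-party EvoTrees.jl package, is just one iterable model and is an assumption here; any model with an iteration parameter will do:

```julia
using MLJ

X, y = make_moons(500)  # a synthetic binary classification task

# load an iterable model (assumes EvoTrees.jl is installed):
EvoTreeClassifier = @load EvoTreeClassifier pkg=EvoTrees

iterated_model = IteratedModel(
    EvoTreeClassifier(),
    measure=log_loss,
    controls=[Step(5), Patience(n=3), NumberLimit(200)],
)

mach = machine(iterated_model, X, y)
fit!(mach)  # iterates until `Patience` or `NumberLimit` triggers a stop
```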
## Extended help

### Options
- `controls=Any[IterationControl.Step(1), EarlyStopping.Patience(5), EarlyStopping.GL(2.0), EarlyStopping.TimeLimit(Dates.Millisecond(108000)), EarlyStopping.InvalidValue()]`: Controls are summarized at https://JuliaAI.github.io/MLJ.jl/dev/getting_started/ but query individual doc-strings for details and advanced options. For creating your own controls, refer to the documentation just cited.

- `resampling=Holdout(fraction_train=0.7)`: The default resampling holds back 30% of data for computing an out-of-sample estimate of performance (the "loss") for loss-based controls such as `WithLossDo`. Specify `resampling=nothing` if all data is to be used for controlled iteration, with each out-of-sample loss replaced by the most recent training loss, assuming this is made available by the model (`supports_training_losses(model) == true`). If the model does not report a training loss, you can use `resampling=InSample()` instead. Otherwise, `resampling` must have type `Holdout` or be a vector with one element of the form `(train_indices, test_indices)` (see the sketch following this list).

- `measure=nothing`: StatisticalMeasures.jl compatible measure for estimating model performance (the "loss", but the orientation is immaterial; i.e., this could be a score). Inferred by default. Ignored if `resampling=nothing`.

- `retrain=false`: If `retrain=true` or `resampling=nothing`, `iterated_model` behaves exactly like the original `model` but with the iteration parameter automatically selected ("learned"). That is, the model is retrained on all available data, using the same number of iterations, once controlled iteration has stopped. This is typically desired if wrapping the iterated model further, or when inserting in a pipeline or other composite model. If `retrain=false` (default) and `resampling isa Holdout`, then `iterated_model` behaves like the original model trained on a subset of the provided data.

- `weights=nothing`: per-observation weights to be passed to `measure` where supported; if unspecified, these are understood to be uniform.

- `class_weights=nothing`: class-weights to be passed to `measure` where supported; if unspecified, these are understood to be uniform.

- `operation=nothing`: Operation, such as `predict` or `predict_mode`, for computing target values, or proxy target values, for consumption by `measure`; automatically inferred by default.

- `check_measure=true`: Specify `false` to override checks on `measure` for compatibility with the training data.

- `iteration_parameter=nothing`: A symbol, such as `:epochs`, naming the iteration parameter of `model`; inferred by default. Note that the actual value of the iteration parameter in the supplied `model` is ignored; only the value of an internal clone is mutated during training the wrapped model.

- `cache=true`: Whether or not model-specific representations of data are cached in between iteration parameter increments; specify `cache=false` to prioritize memory over speed.
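By way of illustration, here is a sketch combining several of these options; the explicit index split, the `rms` measure, and the assumption of a 100-row regression dataset are all arbitrary choices for the example:

```julia
iterated_model = IteratedModel(
    model;                        # any model supporting iteration
    resampling=[(1:70, 71:100)],  # single (train_indices, test_indices) pair
    measure=rms,                  # explicit out-of-sample loss
    retrain=true,                 # retrain on all 100 rows after stopping
    iteration_parameter=:epochs,  # only needed if it cannot be inferred
)
```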
### Training
Training an instance `iterated_model` of `IteratedModel` on some `data` (by binding to a machine and calling `fit!`, for example) performs the following actions:
- Assuming `resampling !== nothing`, the `data` is split into train and test sets, according to the specified `resampling` strategy.

- A clone of the wrapped model, `model`, is bound to the train data in an internal machine, `train_mach`. If `resampling === nothing`, all data is used instead. This machine is the object to which controls are applied. For example, `Callback(mach -> print(fitted_params(mach)))` will print the value of `fitted_params(train_mach)`.

- The iteration parameter of the clone is set to `0`.

- The specified `controls` are repeatedly applied to `train_mach` in sequence, until one of the controls triggers a stop. Loss-based controls (e.g., `Patience()`, `GL()`, `Threshold(0.001)`) use an out-of-sample loss, obtained by applying `measure` to predictions and the test target values. (Specifically, these predictions are those returned by `operation(train_mach)`.) If `resampling === nothing` then the most recent training loss is used instead. Some controls require both out-of-sample and training losses (e.g., `PQ()`). A sketch of such a control mix follows this list.

- Once a stop has been triggered, a clone of `model` is bound to all `data` in a machine called `mach_production` below, unless `retrain == false` (true by default) or `resampling === nothing`, in which case `mach_production` coincides with `train_mach`.
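Here is the control-mix sketch referenced above, combining a loss-logging control with loss-based stopping. The control names are real, but the particular combination, and the assumption that `model` supports iteration, are illustrative:

```julia
controls = [
    Step(2),                                    # apply 2 iterations per control cycle
    WithLossDo(loss -> @info("loss = $loss")),  # inspect each out-of-sample loss
    Patience(n=3),                              # stop after 3 consecutive loss increases
    GL(alpha=2.0),                              # stop when generalization loss exceeds 2.0
]
iterated_model = IteratedModel(model; controls, measure=log_loss)
```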
### Prediction
Calling `predict(mach, Xnew)` in the example above returns `predict(mach_production, Xnew)`. Similar statements hold for `predict_mean`, `predict_mode`, and `predict_median`.
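Schematically, and assuming `Xnew` is new input data compatible with `X`:

```julia
fit!(mach)                  # controlled iteration as described above
yhat = predict(mach, Xnew)  # equivalent to `predict(mach_production, Xnew)`
```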
### Controls that mutate parameters
A control is permitted to mutate the fields (hyper-parameters) of `train_mach.model` (the clone of `model`). For example, to mutate a learning rate one might use the control

```julia
Callback(mach -> mach.model.eta = 1.05*mach.model.eta)
```

However, unless `model` supports warm restarts with respect to changes in that parameter, this will trigger retraining of `train_mach` from scratch, with a different training outcome, which is not recommended.
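For context, here is a sketch embedding such a mutating control; it assumes the wrapped `model` has a learning-rate field named `eta` and supports warm restarts for changes to it:

```julia
# grow the learning rate by 5% on each control cycle:
grow_eta = Callback(mach -> mach.model.eta = 1.05*mach.model.eta)

iterated_model = IteratedModel(
    model;
    controls=[Step(1), grow_eta, NumberLimit(100)],
)
```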
### Warm restarts
In the following example, the second `fit!` call will not restart training of the internal `train_mach`, assuming `model` supports warm restarts:
```julia
iterated_model = IteratedModel(
    model,
    controls = [Step(1), NumberLimit(100)],
)

mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)]
fit!(mach) # train for an *extra* 50 iterations
```
More generally, if `iterated_model` is mutated and `fit!(mach)` is called again, then a warm restart is attempted if the only parameters to change are `model` or `controls` or both.

Specifically, `train_mach.model` is mutated to match the current value of `iterated_model.model` and the iteration parameter of the latter is updated to the last value used in the preceding `fit!(mach)` call. Then repeated application of the (updated) controls begins anew.
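For example, continuing the previous sketch, and assuming `model` has a learning-rate field `eta` with warm-restart support for changes to that parameter:

```julia
iterated_model.model.eta = 0.01  # mutate a hyper-parameter of the wrapped model
fit!(mach)                       # warm restart: `train_mach.model` is synced,
                                 # then the controls are reapplied from there
```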