Controlling Iterative Models
Iterative supervised machine learning models are usually trained until an out-of-sample estimate of the performance satisfies some stopping criterion, such as k
consecutive deteriorations of the performance (see Patience
below). A more sophisticated kind of control might dynamically mutate parameters, such as a learning rate, in response to the behavior of these estimates.
Some iterative model implementations enable some form of automated control, with the method and options for doing so varying from model to model. But sometimes it is up to the user to arrange control, which in the crudest case reduces to manually experimenting with the iteration parameter.
In response to this ad hoc state of affairs, MLJ provides a uniform and feature-rich interface for controlling any iterative model that exposes its iteration parameter as a hyper-parameter, and which implements the "warm restart" behavior described in Machines.
Basic use
As in Tuning Models, iteration control in MLJ is implemented as a model wrapper, which allows composition with other meta-algorithms. Ordinarily, the wrapped model behaves just like the original model, but with the training occurring on a subset of the provided data (to allow computation of an out-of-sample loss) and with the iteration parameter automatically determined by the controls specified in the wrapper.
By setting retrain=true
one can ask that the wrapped model retrain on all supplied data, after learning the appropriate number of iterations from the controlled training phase:
using MLJ
X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0
iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, eta=0.005),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5),
                                         Patience(2),
                                         NumberLimit(100)],
                               retrain=true)
mach = machine(iterated_model, X, y)
julia> fit!(mach)
┌ Info: Training machine(ProbabilisticIteratedModel(model = EvoTrees.EvoTreeClassifier{EvoTrees.MLogLoss}
│  - nrounds: 100
│  - L2: 0.0
│  - lambda: 0.0
│  - gamma: 0.0
│  - eta: 0.005
│  - max_depth: 6
│  - min_weight: 1.0
│  - rowsample: 1.0
│  - colsample: 1.0
│  - nbins: 64
│  - alpha: 0.5
│  - tree_type: binary
│  - rng: Random.MersenneTwister(123)
└ , …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(nrounds)`.
[ Info: final loss: 0.46683584745719836
[ Info: Stop triggered by Patience(2) stopping criterion.
[ Info: Retraining on all provided data. To suppress, specify `retrain=false`.
[ Info: Total of 215 iterations.
trained Machine; does not cache data
  model: ProbabilisticIteratedModel(model = EvoTrees.EvoTreeClassifier{EvoTrees.MLogLoss}
 - nrounds: 100
 - L2: 0.0
 - lambda: 0.0
 - gamma: 0.0
 - eta: 0.005
 - max_depth: 6
 - min_weight: 1.0
 - rowsample: 1.0
 - colsample: 1.0
 - nbins: 64
 - alpha: 0.5
 - tree_type: binary
 - rng: Random.MersenneTwister(123)
, …)
  args:
    1:	Source @547 ⏎ Table{AbstractVector{Continuous}}
    2:	Source @513 ⏎ AbstractVector{Multiclass{2}}
As detailed under IteratedModel
below, the specified controls
are repeatedly applied in sequence to a training machine, constructed under the hood, until one of the controls triggers a stop. Here Step(5)
means "Compute 5 more iterations" (in this case starting from none); Patience(2)
means "Stop at the end of the control cycle if there have been 2 consecutive drops in the log loss"; and NumberLimit(100)
is a safeguard ensuring a stop after 100 control cycles (500 iterations). See Controls provided below for a complete list.
Because iteration is implemented as a wrapper, the "self-iterating" model can be evaluated using cross-validation, say, and the number of iterations on each fold will generally be different:
e = evaluate!(mach, resampling=CV(nfolds=3), measure=log_loss, verbosity=0);
map(e.report_per_fold) do r
    r.n_iterations
end
3-element Vector{Int64}:
340
150
500
Alternatively, one might wrap the self-iterating model in a tuning strategy, using TunedModel
; see Tuning Models. In this way, the optimization of some other hyper-parameter is realized simultaneously with that of the iteration parameter, which will frequently be more efficient than a direct two-parameter search.
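For example, continuing the example above, one might tune eta by wrapping the self-iterating classifier in a TunedModel. The sketch below assumes hypothetical range bounds and a search budget of 10 candidates; during evaluation, each candidate eta then determines its own number of iterations:
r = range(iterated_model, :(model.eta), lower=0.001, upper=0.1, scale=:log)
tuned_iterated_model = TunedModel(model=iterated_model,
                                  tuning=RandomSearch(rng=123),
                                  resampling=CV(nfolds=3),
                                  range=r,
                                  measure=log_loss,
                                  n=10)
mach = machine(tuned_iterated_model, X, y)
fit!(mach, verbosity=0)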
Controls provided
In the table below, mach
is the training machine being iterated, constructed by binding the supplied data to the model
specified in the IteratedModel
wrapper, but trained in each iteration on a subset of the data, according to the value of the resampling
hyper-parameter of the wrapper (using all data if resampling=nothing
).
control | description | can trigger a stop |
---|---|---|
Step (n=1) | Train model for n more iterations | no |
TimeLimit (t=0.5) | Stop after t hours | yes |
NumberLimit (n=100) | Stop after n applications of the control | yes |
NumberSinceBest (n=6) | Stop when best loss occurred n control applications ago | yes |
InvalidValue () | Stop when NaN , Inf or -Inf loss/training loss encountered | yes |
Threshold (value=0.0) | Stop when loss < value | yes |
GL (alpha=2.0) | † Stop after the "generalization loss (GL)" exceeds alpha | yes |
PQ (alpha=0.75, k=5) | † Stop after "progress-modified GL" exceeds alpha | yes |
Patience (n=5) | † Stop after n consecutive loss increases | yes |
Warmup (c; n=1) | Wait for n loss updates before checking criteria c | no |
Info (f=identity) | Log to Info the value of f(mach), where mach is the current machine | no |
Warn (predicate; f="") | Log to Warn the value of f or f(mach) , if predicate(mach) holds | no |
Error (predicate; f="") | Log to Error the value of f or f(mach) , if predicate(mach) holds and then stop | yes |
Callback (f=mach->nothing) | Call f(mach) | yes |
WithNumberDo (f=n->@info(n)) | Call f(n + 1) where n is the number of complete control cycles so far | yes |
WithIterationsDo (f=i->@info("iterations: $i")) | Call f(i) , where i is total number of iterations | yes |
WithLossDo (f=x->@info("loss: $x")) | Call f(loss) where loss is the current loss | yes |
WithTrainingLossesDo (f=v->@info(v)) | Call f(v) where v is the current batch of training losses | yes |
WithEvaluationDo (f=e->@info("evaluation: $e")) | Call f(e) where e is the current performance evaluation object | yes |
WithFittedParamsDo (f=fp->@info("fitted_params: $fp")) | Call f(fp) where fp is fitted parameters of training machine | yes |
WithReportDo (f=r->@info("report: $r")) | Call f(r) where r is the training machine report | yes |
WithModelDo (f=m->@info("model: $m")) | Call f(m) where m is the model, which may be mutated by f | yes |
WithMachineDo (f=mach->@info("machine: $mach")) | Call f(mach) where mach is the training machine in its current state | yes |
Save (filename="machine.jls") | Save current training machine to machine1.jls, machine2.jls, etc | yes |
Table 1. Atomic controls. Some advanced options are omitted.
† For more on these controls see Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.
Stopping option. All the following controls trigger a stop if the provided function f
returns true
and stop_if_true=true
is specified in the constructor: Callback
, WithNumberDo
, WithLossDo
, WithTrainingLossesDo
.
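For example, one can build a custom plateau-detecting stop from WithLossDo; in the following sketch the helper plateau, the losses vector and the tolerance are illustrative, not part of the API:
losses = Float64[]
function plateau(loss; tol=1e-4)
    push!(losses, loss)                              # record the latest out-of-sample loss
    length(losses) < 3 && return false
    recent = losses[end-2:end]
    return maximum(recent) - minimum(recent) < tol   # `true` triggers a stop
end

plateau_control = WithLossDo(f=plateau,
                             stop_if_true=true,
                             stop_message="Out-of-sample loss has plateaued.")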
There are also several control wrappers to modify a control's behavior:
wrapper | description |
---|---|
IterationControl.skip (control, predicate=1) | Apply control every predicate applications of the control wrapper (can also be a function; see doc-string) |
IterationControl.louder (control, by=1) | Increase the verbosity level of control by the specified value (negative values lower verbosity) |
IterationControl.with_state_do (control; f=...) | Apply control and call f(x) where x is the internal state of control; useful for debugging. Default f logs state to Info . Warning: internal control state is not yet part of the public API. |
IterationControl.composite (controls...) | Apply each control in controls in sequence; used internally by IterationControl.jl |
Table 2. Wrapped controls
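For example, one might log the out-of-sample loss only every tenth control cycle and make a checkpointing control more verbose, roughly as in this sketch (the filename is illustrative only):
import IterationControl # or MLJ.IterationControl

controls = [Step(5),
            IterationControl.skip(WithLossDo(), predicate=10),     # log the loss every 10th cycle
            IterationControl.louder(Save("checkpoint.jls"), by=1), # raise Save's verbosity
            NumberLimit(100)]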
Using training losses, and controlling model tuning
Some iterative models report a training loss, as a byproduct of a fit! call, and these can be used in two ways:

1. To supplement an out-of-sample estimate of the loss in deciding when to stop, as in the PQ stopping criterion (see Prechelt, Lutz (1998)); or

2. As a (generally less reliable) substitute for an out-of-sample loss, when wishing to train exclusively on all supplied data.
To have IteratedModel
bind all data to the training machine and use training losses in place of an out-of-sample loss, specify resampling=nothing
. To check if MyFavoriteIterativeModel reports training losses, load the model code and inspect supports_training_losses(MyFavoriteIterativeModel) (or do info("MyFavoriteIterativeModel")).
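For example, reusing the EvoTreeClassifier from the introduction, and assuming it reports training losses, a sketch of controlled iteration driven purely by training losses is:
iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, eta=0.005),
                               resampling=nothing,    # bind all data; use training losses
                               controls=[Step(5),
                                         Patience(2),  # now applied to the training losses
                                         NumberLimit(100)])
mach = machine(iterated_model, X, y) |> fit!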
Controlling model tuning
An example of scenario 2 occurs when controlling hyperparameter optimization (model tuning). Recall that MLJ's TunedModel
wrapper is implemented as an iterative model. Moreover, this wrapper reports, as a training loss, the lowest value of the optimization objective function so far (typically the lowest value of an out-of-sample loss, or -1 times an out-of-sample score). One may want to simply end the hyperparameter search when this value meets the NumberSinceBest
stopping criterion discussed below, say, rather than introducing an extra layer of resampling to first "learn" the optimal value of the iteration parameter.
In the following example, we conduct a RandomSearch
for the optimal value of the regularization parameter lambda
in a RidgeRegressor
using 6-fold cross-validation. By wrapping our "self-tuning" version of the regressor as an IteratedModel
, with resampling=nothing
and NumberSinceBest(20)
in the controls, we terminate the search when the number of lambda
values tested since the previous best cross-validation loss reaches 20.
using MLJ
X, y = @load_boston;
RidgeRegressor = @load RidgeRegressor pkg=MLJLinearModels verbosity=0
model = RidgeRegressor()
r = range(model, :lambda, lower=-1, upper=2, scale=x->10^x)
self_tuning_model = TunedModel(model=model,
                               tuning=RandomSearch(rng=123),
                               resampling=CV(nfolds=6),
                               range=r,
                               measure=mae);
iterated_model = IteratedModel(model=self_tuning_model,
                               resampling=nothing,
                               controls=[Step(1), NumberSinceBest(20), NumberLimit(1000)])
mach = machine(iterated_model, X, y)
julia> fit!(mach)
[ Info: Training machine(DeterministicIteratedModel(model = DeterministicTunedModel(model = RidgeRegressor(lambda = 1.0, …), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(n)`.
[ Info: final loss: 3.8928800658727467
[ Info: final training loss: 3.8928800658727467
[ Info: Stop triggered by NumberSinceBest(20) stopping criterion.
[ Info: Total of 45 iterations.
trained Machine; does not cache data
  model: DeterministicIteratedModel(model = DeterministicTunedModel(model = RidgeRegressor(lambda = 1.0, …), …), …)
  args:
    1:	Source @684 ⏎ Table{AbstractVector{Continuous}}
    2:	Source @923 ⏎ AbstractVector{Continuous}
julia> report(mach).model_report.best_model
RidgeRegressor(
  lambda = 0.4243170708090101,
  fit_intercept = true,
  penalize_intercept = false,
  scale_penalty_with_samples = true,
  solver = nothing)
We can use mach
here to directly obtain predictions using the optimal model (trained on all data), as in
julia> predict(mach, selectrows(X, 1:4))
4-element Vector{Float64}:
 31.309570596541448
 25.24911135120517
 29.89525728277618
 29.237112147518744
Custom controls
Under the hood, control in MLJIteration is implemented using IterationControl.jl. Rather than iterating a training machine directly, we iterate a wrapped version of this object, which includes other information that a control may want to access, such as the MLJ evaluation object. This information is summarized under The training machine wrapper below.
Controls must implement two update!
methods, one for initializing the control's state on the first application of the control (this state being external to the control struct
) and one for all subsequent control applications, which generally updates the state as well. There are two optional methods: done
, for specifying conditions triggering a stop, and takedown
for specifying actions to perform at the end of controlled training.
We summarize the training algorithm, as it relates to controls, after giving a simple example.
Example 1 - Non-uniform iteration steps
Below we define a control, IterateFromList(list)
, to train, on each application of the control, until the iteration count reaches the next value in a user-specified list
, triggering a stop when the list
is exhausted. For example, to train on iteration counts on a log scale, one might use IterateFromList([round(Int, 10^x) for x in range(1, 2, length=10)]).
In the code, wrapper
is an object that wraps the training machine (see above). The variable n
is a counter for control cycles (unused in this example).
import IterationControl # or MLJ.IterationControl
struct IterateFromList
    list::Vector{<:Int} # list of iteration parameter values
    IterateFromList(v) = new(unique(sort(v)))
end
function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n)
    Δi = control.list[1]
    verbosity > 1 && @info "Training $Δi more iterations. "
    MLJIteration.train!(wrapper, Δi) # trains the training machine
    return (index = 2, )
end

function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n, state)
    index = state.index
    Δi = control.list[index] - wrapper.n_iterations
    verbosity > 1 && @info "Training $Δi more iterations. "
    MLJIteration.train!(wrapper, Δi)
    return (index = index + 1, )
end
The first update
method will be called the first time the control is applied, returning an initialized state = (index = 2,)
, which is passed to the second update
method, which is called on subsequent control applications (and which returns the updated state
).
A done
method articulates the criterion for stopping:
IterationControl.done(control::IterateFromList, state) =
    state.index > length(control.list)
For the sake of illustration, we'll implement a takedown
method; its return value is included in the IteratedModel
report:
function IterationControl.takedown(control::IterateFromList, verbosity, state)
    verbosity > 1 && @info "Stepped through these values of the "*
        "iteration parameter: $(control.list)"
    return (iteration_values=control.list, )
end
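The new control can be passed to IteratedModel like any built-in control; for example, reusing the classifier from the introduction (a sketch):
iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, eta=0.005),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[IterateFromList([1, 2, 4, 8, 16, 32, 64, 100]),
                                         InvalidValue()])
mach = machine(iterated_model, X, y)
fit!(mach)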
The training machine wrapper
A training machine wrapper
has these properties:
- wrapper.machine - the training machine, type Machine
- wrapper.model - the mutable atomic model, coinciding with wrapper.machine.model
- wrapper.n_cycles - the number of IterationControl.train!(wrapper, _) calls so far; generally the current control cycle count
- wrapper.n_iterations - the total number of iterations applied to the model so far
- wrapper.Δiterations - the number of iterations applied in the last IterationControl.train!(wrapper, _) call
- wrapper.loss - the out-of-sample loss (based on the first measure in measures)
- wrapper.training_losses - the last batch of training losses (if reported by model), an abstract vector of length wrapper.Δiterations
- wrapper.evaluation - the complete MLJ performance evaluation object, which has the following properties: measure, measurement, per_fold, per_observation, fitted_params_per_fold, report_per_fold (here there is only one fold). For further details, see Evaluating Model Performance.
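As an illustration of read-only access to these properties, here is a sketch of a custom control (the name LossHistory is hypothetical) that records the iteration count and out-of-sample loss at each control cycle and exposes the record in the IteratedModel report:
import IterationControl # or MLJ.IterationControl

struct LossHistory end

# first application: initialize the (external) state
IterationControl.update!(::LossHistory, wrapper, verbosity, n) =
    (history = Any[(wrapper.n_iterations, wrapper.loss)],)

# subsequent applications: append to the state
IterationControl.update!(::LossHistory, wrapper, verbosity, n, state) =
    (history = push!(state.history, (wrapper.n_iterations, wrapper.loss)),)

# include the record in the IteratedModel report:
IterationControl.takedown(::LossHistory, verbosity, state) =
    (loss_history = state.history,)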
The training algorithm
Here now is a simplified description of the training of an IteratedModel
. First, the atomic model
is bound in a machine - the training machine above - to a subset of the supplied data, and then wrapped in an object called wrapper
below. To train the training machine for i
more iterations, and update the other data in the wrapper, requires the call MLJIteration.train!(wrapper, i)
. Only controls can make this call (e.g., Step(...)
, or IterateFromList(...)
above). If we assume for simplicity there is only a single control, called control
, then training proceeds as follows:
n = 1 # initialize control cycle counter
state = update!(control, wrapper, verbosity, n)
finished = done(control, state)

# subsequent training events:
while !finished
    n += 1
    state = update!(control, wrapper, verbosity, n, state)
    finished = done(control, state)
end

# finalization:
return takedown(control, verbosity, state)
Example 2 - Cyclic learning rates
The control below implements a triangular cyclic learning rate policy, close to that described in L. N. Smith (2017): "Cyclical Learning Rates for Training Neural Networks," 2017 IEEE Winter Conference on Applications of Computer Vision (WACV), Santa Rosa, CA, USA, pp. 464-472. [In that paper learning rates are mutated (slowly) during each training iteration (epoch) while here mutations can occur once per iteration of the model, at most.]
For the sake of illustration, we suppose the iterative model, model
, specified in the IteratedModel
constructor, has a field called :learning_parameter
, and that mutating this parameter does not trigger cold-restarts.
struct CycleLearningRate{F<:AbstractFloat}
    stepsize::Int
    lower::F
    upper::F
end
# return one cycle of learning rate values:
function one_cycle(c::CycleLearningRate)
    rise = range(c.lower, c.upper, length=c.stepsize + 1)
    fall = reverse(rise)
    return vcat(rise[1:end - 1], fall[1:end - 1])
end
function IterationControl.update!(control::CycleLearningRate,
                                  wrapper,
                                  verbosity,
                                  n,
                                  state = (learning_rates=nothing, ))
    rates = state.learning_rates === nothing ? one_cycle(control) : state.learning_rates
    index = mod(n, length(rates)) + 1
    r = rates[index]
    verbosity > 1 && @info "learning rate: $r"
    wrapper.model.learning_parameter = r  # the hypothetical field described above
    return (learning_rates = rates,)
end
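Such a control would typically be combined with a Step and a stopping control. A sketch (recall that :learning_parameter and the supplied model are hypothetical):
iterated_model = IteratedModel(model=model,   # the hypothetical model described above
                               resampling=Holdout(),
                               measure=log_loss,
                               controls=[CycleLearningRate(10, 0.001, 0.1), # stepsize, lower, upper
                                         Step(1),
                                         NumberLimit(100)])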
API Reference
MLJIteration.IteratedModel
— FunctionIteratedModel(model;
controls=MLJIteration.DEFAULT_CONTROLS,
resampling=Holdout(),
measure=nothing,
retrain=false,
advanced_options...,
)
Wrap the specified supervised model
in the specified iteration controls
. Here model
should support iteration, which is true if iteration_parameter(model) is different from nothing.
Available controls: Step(), Info(), Warn(), Error(), Callback(), WithLossDo(), WithTrainingLossesDo(), WithNumberDo(), Data(), Disjunction(), GL(), InvalidValue(), Never(), NotANumber(), NumberLimit(), NumberSinceBest(), PQ(), Patience(), Threshold(), TimeLimit(), Warmup(), WithIterationsDo(), WithEvaluationDo(), WithFittedParamsDo(), WithReportDo(), WithMachineDo(), WithModelDo(), CycleLearningRate() and Save().
To make out-of-sample losses available to the controls, the wrapped model
is only trained on part of the data, as iteration proceeds. The user may want to force retraining on all data after controlled iteration has finished by specifying retrain=true
. See also "Training", and the retrain
option, under "Extended help" below.
Extended help
Options
- controls=Any[Step(1), Patience(5), GL(2.0), TimeLimit(Dates.Millisecond(108000)), InvalidValue()]: Controls are summarized at https://JuliaAI.github.io/MLJ.jl/dev/getting_started/ but query individual doc-strings for details and advanced options. For creating your own controls, refer to the documentation just cited.
- resampling=Holdout(fraction_train=0.7): The default resampling holds back 30% of data for computing an out-of-sample estimate of performance (the "loss") for loss-based controls such as WithLossDo. Specify resampling=nothing if all data is to be used for controlled iteration, with each out-of-sample loss replaced by the most recent training loss, assuming this is made available by the model (supports_training_losses(model) == true). If the model does not report a training loss, you can use resampling=InSample() instead. Otherwise, resampling must have type Holdout or be a vector with one element of the form (train_indices, test_indices) (see the sketch following this list).
- measure=nothing: StatisticalMeasures.jl compatible measure for estimating model performance (the "loss", but the orientation is immaterial - i.e., this could be a score). Inferred by default. Ignored if resampling=nothing.
- retrain=false: If retrain=true or resampling=nothing, iterated_model behaves exactly like the original model but with the iteration parameter automatically selected ("learned"). That is, the model is retrained on all available data, using the same number of iterations, once controlled iteration has stopped. This is typically desired if wrapping the iterated model further, or when inserting in a pipeline or other composite model. If retrain=false (default) and resampling isa Holdout, then iterated_model behaves like the original model trained on a subset of the provided data.
- weights=nothing: per-observation weights to be passed to measure where supported; if unspecified, these are understood to be uniform.
- class_weights=nothing: class-weights to be passed to measure where supported; if unspecified, these are understood to be uniform.
- operation=nothing: Operation, such as predict or predict_mode, for computing target values, or proxy target values, for consumption by measure; automatically inferred by default.
- check_measure=true: Specify false to override checks on measure for compatibility with the training data.
- iteration_parameter=nothing: A symbol, such as :epochs, naming the iteration parameter of model; inferred by default. Note that the actual value of the iteration parameter in the supplied model is ignored; only the value of an internal clone is mutated during training the wrapped model.
- cache=true: Whether or not model-specific representations of data are cached in between iteration parameter increments; specify cache=false to prioritize memory over speed.
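For example, the resampling option can specify an explicit single train/test split by row indices; in the following sketch (indices and model hypothetical) the out-of-sample loss is computed on observations 71 to 100:
iterated_model = IteratedModel(model=model,
                               resampling=[(1:70, 71:100)],  # (train_indices, test_indices)
                               measure=log_loss,
                               controls=[Step(5), Patience(3), NumberLimit(50)])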
Training
Training an instance iterated_model
of IteratedModel
on some data
(by binding to a machine and calling fit!
, for example) performs the following actions:
1. Assuming resampling !== nothing, the data is split into train and test sets, according to the specified resampling strategy.
2. A clone of the wrapped model, model, is bound to the train data in an internal machine, train_mach. If resampling === nothing, all data is used instead. This machine is the object to which controls are applied. For example, Callback(mach -> fitted_params(mach) |> print) will print the value of fitted_params(train_mach).
3. The iteration parameter of the clone is set to 0.
4. The specified controls are repeatedly applied to train_mach in sequence, until one of the controls triggers a stop. Loss-based controls (eg, Patience(), GL(), Threshold(0.001)) use an out-of-sample loss, obtained by applying measure to predictions and the test target values. (Specifically, these predictions are those returned by operation(train_mach).) If resampling === nothing then the most recent training loss is used instead. Some controls require both out-of-sample and training losses (eg, PQ()).
5. Once a stop has been triggered, a clone of model is bound to all data in a machine called mach_production below, unless retrain == false (the default) or resampling === nothing, in which case mach_production coincides with train_mach.
Prediction
Calling predict(mach, Xnew)
in the example above returns predict(mach_production, Xnew)
. Similar statements hold for predict_mean
, predict_mode
, predict_median
.
Controls that mutate parameters
A control is permitted to mutate the fields (hyper-parameters) of train_mach.model
(the clone of model
). For example, to mutate a learning rate one might use the control
Callback(mach -> mach.model.eta = 1.05*mach.model.eta)
However, unless model
supports warm restarts with respect to changes in that parameter, this will trigger retraining of train_mach
from scratch, with a different training outcome, which is not recommended.
Warm restarts
In the following example, the second fit!
call will not restart training of the internal train_mach
, assuming model
supports warm restarts:
iterated_model = IteratedModel(
    model,
    controls = [Step(1), NumberLimit(100)],
)
mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)]
fit!(mach) # train for an *extra* 50 iterations
More generally, if iterated_model
is mutated and fit!(mach)
is called again, then a warm restart is attempted if the only parameters to change are model
or controls
or both.
Specifically, train_mach.model
is mutated to match the current value of iterated_model.model
and the iteration parameter of the latter is updated to the last value used in the preceding fit!(mach)
call. Then repeated application of the (updated) controls begin anew.
Controls
IterationControl.Step
— TypeStep(; n=1)
An iteration control, as in, Step(2)
.
Train for n
more iterations. Will never trigger a stop.
EarlyStopping.TimeLimit
— TypeTimeLimit(; t=0.5)
An early stopping criterion for loss-reporting iterative algorithms.
Stopping is triggered after t
hours have elapsed since the stopping criterion was initiated.
Any Julia built-in Real
type can be used for t
. Subtypes of Period
may also be used, as in TimeLimit(t=Minute(30))
.
Internally, t
is rounded to the nearest millisecond.
EarlyStopping.NumberLimit
— TypeNumberLimit(; n=100)
An early stopping criterion for loss-reporting iterative algorithms.
A stop is triggered by n
consecutive loss updates, excluding "training" loss updates.
If wrapped in a stopper::EarlyStopper
, this is the number of calls to done!(stopper)
.
EarlyStopping.NumberSinceBest
— TypeNumberSinceBest(; n=6)
An early stopping criterion for loss-reporting iterative algorithms.
A stop is triggered when the number of calls to the control, since the lowest value of the loss so far, is n
.
For a customizable loss-based stopping criterion, use WithLossDo
or WithTrainingLossesDo
with the stop_if_true=true
option.
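For example, a stop based on the most recent training loss (rather than the out-of-sample loss) falling below a target value might be sketched as follows; the target is hypothetical:
target = 0.1
stop_on_training_loss =
    WithTrainingLossesDo(f=v -> last(v) < target,   # v is the last batch of training losses
                         stop_if_true=true,
                         stop_message="Training loss dropped below $target.")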
EarlyStopping.InvalidValue
— TypeInvalidValue()
An early stopping criterion for loss-reporting iterative algorithms.
Stop if a loss (or training loss) is NaN
, Inf
or -Inf
(or, more precisely, if isnan(loss)
or isinf(loss)
is true
).
For a customizable loss-based stopping criterion, use WithLossDo
or WithTrainingLossesDo
with the stop_if_true=true
option.
EarlyStopping.Threshold
— TypeThreshold(; value=0.0)
An early stopping criterion for loss-reporting iterative algorithms.
A stop is triggered as soon as the loss drops below value
.
For a customizable loss-based stopping criterion, use WithLossDo
or WithTrainingLossesDo
with the stop_if_true=true
option.
EarlyStopping.GL
— TypeGL(; alpha=2.0)
An early stopping criterion for loss-reporting iterative algorithms.
A stop is triggered when the (rescaled) generalization loss exceeds the threshold alpha
.
Terminology. Suppose $E_1, E_2, ..., E_t$ are a sequence of losses, for example, out-of-sample estimates of the loss associated with some iterative machine learning algorithm. Then the generalization loss at time t
, is given by
$GL_t = 100 \frac{E_t - E_{opt}}{|E_{opt}|}$
where $E_{opt}$ is the minimum value of the sequence.
EarlyStopping.PQ
— TypePQ(; alpha=0.75, k=5, tol=eps(Float64))
A stopping criterion for training iterative supervised learners.
A stop is triggered when Prechelt's progress-modified generalization loss exceeds the threshold $PQ_T > alpha$, or if the training progress drops below $P_j ≤ tol$. Here k
is the number of training (in-sample) losses used to estimate the training progress.
Context and explanation of terminology
The training progress at time $j$ is defined by
$P_j = 1000 |M - m|/|m|$
where $M$ is the mean of the last k
training losses $F_1, F_2, …, F_k$ and $m$ is the minimum value of those losses.
The progress-modified generalization loss at time $t$ is then given by
$PQ_t = GL_t / P_t$
where $GL_t$ is the generalization loss at time $t$; see GL
.
PQ will stop when the following are true:

- At least k training samples have been collected via done!(c::PQ, loss; training = true) or update_training(c::PQ, loss, state)
- The last update was an out-of-sample update (done!(::PQ, loss; training=true) is always false)
- The progress-modified generalization loss exceeds the threshold $PQ_t > alpha$ OR the training progress stalls ($P_j ≤ tol$).
EarlyStopping.Patience
— TypePatience(; n=5)
An early stopping criterion for loss-reporting iterative algorithms.
A stop is triggered by n
consecutive increases in the loss.
Denoted "UPs" in Prechelt, Lutz (1998): "Early Stopping- But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer..
For a customizable loss-based stopping criterion, use WithLossDo
or WithTrainingLossesDo
with the stop_if_true=true
option.
IterationControl.Info
— TypeInfo(f=identity)
An iteration control, as in, Info(my_loss_function)
.
Log to Info
the value of f(m)
, where m
is the object being iterated. If IterationControl.expose(m)
has been overloaded, then log f(expose(m))
instead.
Can be suppressed by setting the global verbosity level sufficiently low.
IterationControl.Warn
— TypeWarn(predicate; f="")
An iteration control, as in, Warn(m -> length(m.cache) > 100, f="Memory low")
.
If predicate(m)
is true
, then log to Warn
the value of f
(or f(IterationControl.expose(m))
if f
is a function). Here m
is the object being iterated.
Can be suppressed by setting the global verbosity level sufficiently low.
IterationControl.Error
— TypeError(predicate; f="", exception=nothing)
An iteration control, as in, Error(m -> isnan(m.bias), f="Bias overflow!")
.
If predicate(m)
is true
, then log at the Error
level the value of f
(or f(IterationControl.expose(m))
if f
is a function) and stop iteration at the end of the current control cycle. Here m
is the object being iterated.
Specify exception=...
to throw an immediate exception, without waiting until the end of the control cycle.
IterationControl.Callback
— TypeCallback(f=_->nothing, stop_if_true=false, stop_message=nothing, raw=false)
An iteration control, as in, Callback(m->put!(v, my_loss_function(m))).
.
Call f(IterationControl.expose(m))
, where m
is the object being iterated, unless raw=true
, in which case call f(m)
(guaranteed if expose
has not been overloaded.) If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
IterationControl.WithNumberDo
— TypeWithNumberDo(f=n->@info("number: $n"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithNumberDo(n->put!(my_channel, n))
.
Call f(n + 1)
, where n
is the number of complete control cycles of the control (so, n = 1, 2, 3, ..., unless the control is wrapped in IterationControl.skip).
If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithIterationsDo
— TypeWithIterationsDo(f=x->@info("iterations: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithIterationsDo(x->put!(my_channel, x))
.
Call f(x)
, where x
is the current number of model iterations (generally more than the number of control cycles). If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
IterationControl.WithLossDo
— TypeWithLossDo(f=x->@info("loss: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithLossDo(x->put!(my_losses, x))
.
Call f(loss)
, where loss
is the current loss.
If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
IterationControl.WithTrainingLossesDo
— TypeWithTrainingLossesDo(f=v->@info("training: $v"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithTrainingLossesDo(v->put!(my_losses, last(v))).
.
Call f(training_losses)
, where training_losses
is the most recent batch of training losses.
If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithEvaluationDo
— TypeWithEvaluationDo(f=x->@info("evaluation: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithEvaluationDo(x->put!(my_channel, x))
.
Call f(x)
, where x
is the latest performance evaluation, as returned by evaluate!(train_mach, resampling=..., ...)
. Not valid if resampling=nothing
. If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithFittedParamsDo
— TypeWithFittedParamsDo(f=x->@info("fitted_params: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithFittedParamsDo(x->put!(my_channel, x))
.
Call f(x)
, where x = fitted_params(mach)
is the fitted parameters of the training machine, mach
, in its current state. If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithReportDo
— TypeWithReportDo(f=x->@info("report: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithReportDo(x->put!(my_channel, x))
.
Call f(x)
, where x = report(mach)
is the report associated with the training machine, mach
, in its current state. If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithModelDo
— TypeWithModelDo(f=x->@info("model: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithModelDo(x->put!(my_channel, x))
.
Call f(x)
, where x
is the model associated with the training machine; f
may mutate x
, as in f(x) = (x.learning_rate *= 0.9)
. If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.WithMachineDo
— TypeWithMachineDo(f=x->@info("machine: $x"), stop_if_true=false, stop_message=nothing)
An iteration control, as in, WithMachineDo(x->put!(my_channel, x))
.
Call f(x)
, where x
is the training machine in its current state. If stop_if_true
is true
, then trigger an early stop if the value returned by f
is true
, logging the stop_message
if specified.
MLJIteration.Save
— TypeSave(filename="machine.jls")
An iteration control, as in, Save("run3/machine.jls")
.
Save the current state of the machine being iterated to disk, using the provided filename
, decorated with a number, as in "run3/machine42.jls". The default behaviour uses the Serialization module but this can be changed by setting the method=save_fn(::String, ::Any)
argument where save_fn
is any serialization method. For more on what is meant by "the machine being iterated", see IteratedModel
.
Control wrappers
IterationControl.skip
— FunctionIterationControl.skip(control, predicate=1)
An iteration control wrapper.
If predicate
is an integer, k
: Apply control
on every k
calls to apply the wrapped control, starting with the k
th call.
If predicate
is a function: Apply control
as usual when predicate(n + 1)
is true
but otherwise skip. Here n
is the number of control cycles applied so far.
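For example, a function predicate can implement a warm-up period, applying the wrapped control only after the first 50 control cycles (a sketch; the filename is illustrative):
IterationControl.skip(Save("checkpoint.jls"), predicate=n -> n > 50)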
IterationControl.louder
— FunctionIterationControl.louder(control, by=1)
Wrap control
to make it more (or less) verbose. The same as control
, but as if the global verbosity
were increased by the value by
.
IterationControl.with_state_do
— FunctionIterationControl.with_state_do(control,
f=x->@info "$(typeof(control)) state: $x")
Wrap control
to give access to its internal state. Acts exactly like control
except that f
is called on the internal state of control
. If f
is not specified, the control type and state are logged to Info
at every update (useful for debugging new controls).
Warning. The internal state of a control is not yet considered part of the public interface and could change in any pre-1.0 release of IterationControl.jl.
IterationControl.composite
— Functioncomposite(controls...)
Construct an iteration control that applies the specified controls
in sequence.