Controlling Iterative Models

Iterative supervised machine learning models are usually trained until an out-of-sample estimate of the performance satisfies some stopping criterion, such as k consecutive deteriorations of the performance (see Patience below). A more sophisticated kind of control might dynamically mutate parameters, such as a learning rate, in response to the behavior of these estimates.

Some iterative model implementations enable some form of automated control, with the method and options for doing so varying from model to model. But sometimes it is up to the user to arrange control, which in the crudest case reduces to manually experimenting with the iteration parameter.

In response to this ad hoc state of affairs, MLJ provides a uniform and feature-rich interface for controlling any iterative model that exposes its iteration parameter as a hyper-parameter, and which implements the "warm restart" behavior described in Machines.

Basic use

As in Tuning Models, iteration control in MLJ is implemented as a model wrapper, which allows composition with other meta-algorithms. Ordinarily, the wrapped model behaves just like the original model, but with the training occurring on a subset of the provided data (to allow computation of an out-of-sample loss) and with the iteration parameter automatically determined by the controls specified in the wrapper.

By setting retrain=true one can ask that the wrapped model retrain on all supplied data, after learning the appropriate number of iterations from the controlled training phase:

using MLJ

X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, eta=0.005),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5),
                                         Patience(2),
                                         NumberLimit(100)],
                               retrain=true)

mach = machine(iterated_model, X, y)
julia> fit!(mach)
┌ Info: Training machine(ProbabilisticIteratedModel(model = EvoTrees.EvoTreeClassifier{EvoTrees.MLogLoss}
│  - nrounds: 100
│  - L2: 0.0
│  - lambda: 0.0
│  - gamma: 0.0
│  - eta: 0.005
│  - max_depth: 6
│  - min_weight: 1.0
│  - rowsample: 1.0
│  - colsample: 1.0
│  - nbins: 64
│  - alpha: 0.5
│  - tree_type: binary
│  - rng: Random.MersenneTwister(123)
└ , …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(nrounds)`. 
[ Info: final loss: 0.5473439182480309
[ Info: Stop triggered by Patience(2) stopping criterion. 
[ Info: Retraining on all provided data. To suppress, specify `retrain=false`. 
[ Info: Total of 120 iterations. 
trained Machine; does not cache data
  model: ProbabilisticIteratedModel(model = EvoTrees.EvoTreeClassifier{EvoTrees.MLogLoss}
 - nrounds: 100
 - L2: 0.0
 - lambda: 0.0
 - gamma: 0.0
 - eta: 0.005
 - max_depth: 6
 - min_weight: 1.0
 - rowsample: 1.0
 - colsample: 1.0
 - nbins: 64
 - alpha: 0.5
 - tree_type: binary
 - rng: Random.MersenneTwister(123)
, …)
  args:
    1:	Source @587 ⏎ Table{AbstractVector{Continuous}}
    2:	Source @687 ⏎ AbstractVector{Multiclass{2}}

As detailed under IteratedModel below, the specified controls are repeatedly applied in sequence to a training machine, constructed under the hood, until one of the controls triggers a stop. Here Step(5) means "Compute 5 more iterations" (in this case starting from none); Patience(2) means "Stop at the end of the control cycle if there have been 2 consecutive increases in the log loss" (that is, 2 consecutive deteriorations in performance); and NumberLimit(100) is a safeguard ensuring a stop after 100 control cycles (500 iterations). See Controls provided below for a complete list.
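The semantics of these three controls can be illustrated without MLJ. The sketch below is a simplification written for this page, not MLJ's implementation: the function simulate_controls and the loss curve toy_loss are hypothetical. It applies the logic of Step, Patience and NumberLimit once per control cycle to a made-up out-of-sample loss and reports where training would stop.

```julia
# Plain-Julia sketch (no MLJ): mimic Step(step), Patience(patience) and
# NumberLimit(max_cycles) on a hypothetical loss curve `loss(i)`, where i is
# the total number of iterations trained so far.
function simulate_controls(loss; step = 5, patience = 2, max_cycles = 100)
    iterations = 0
    cycles = 0
    previous = Inf      # loss at the end of the previous control cycle
    increases = 0       # number of consecutive loss increases
    while cycles < max_cycles               # NumberLimit: at most max_cycles cycles
        cycles += 1
        iterations += step                  # Step: train `step` more iterations
        current = loss(iterations)
        increases = current > previous ? increases + 1 : 0
        previous = current
        increases >= patience && break      # Patience: stop after `patience` increases in a row
    end
    return (iterations = iterations, cycles = cycles)
end

# Hypothetical loss curve with its minimum at 100 iterations:
toy_loss(i) = (i - 100)^2 / 10_000 + 0.5

simulate_controls(toy_loss)    # (iterations = 110, cycles = 22)
simulate_controls(i -> 1 / i)  # loss never increases: (iterations = 500, cycles = 100)
```

In the second call the loss decreases forever, so only the NumberLimit safeguard ends training, after 100 cycles of 5 iterations each.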

Because iteration is implemented as a wrapper, the "self-iterating" model can be evaluated using cross-validation, say, and the number of iterations on each fold will generally be different:

e = evaluate!(mach, resampling=CV(nfolds=3), measure=log_loss, verbosity=0);
map(e.report_per_fold) do r
    r.n_iterations
end
3-element Vector{Int64}:
 350
  60
 180

Alternatively, one might wrap the self-iterating model in a tuning strategy, using TunedModel; see Tuning Models. In this way, the optimization of some other hyper-parameter is realized simultaneously with that of the iteration parameter, which will frequently be more efficient than a direct two-parameter search.

Controls provided

In the table below, mach is the training machine being iterated, constructed by binding the supplied data to the model specified in the IteratedModel wrapper, but trained in each iteration on a subset of the data, according to the value of the resampling hyper-parameter of the wrapper (using all data if resampling=nothing).

| control | description | can trigger a stop |
|---|---|---|
| Step(n=1) | Train model for n more iterations | no |
| TimeLimit(t=0.5) | Stop after t hours | yes |
| NumberLimit(n=100) | Stop after n applications of the control | yes |
| NumberSinceBest(n=6) | Stop when best loss occurred n control applications ago | yes |
| InvalidValue() | Stop when NaN, Inf or -Inf loss/training loss encountered | yes |
| Threshold(value=0.0) | Stop when loss < value | yes |
| GL(alpha=2.0) † | Stop after the "generalization loss (GL)" exceeds alpha | yes |
| PQ(alpha=0.75, k=5) † | Stop after "progress-modified GL" exceeds alpha | yes |
| Patience(n=5) † | Stop after n consecutive loss increases | yes |
| Warmup(c; n=1) | Wait for n loss updates before checking criteria c | no |
| Info(f=identity) | Log to Info the value of f(mach), where mach is current machine | no |
| Warn(predicate; f="") | Log to Warn the value of f or f(mach), if predicate(mach) holds | no |
| Error(predicate; f="") | Log to Error the value of f or f(mach), if predicate(mach) holds and then stop | yes |
| Callback(f=mach->nothing) | Call f(mach) | yes |
| WithNumberDo(f=n->@info(n)) | Call f(n + 1) where n is the number of complete control cycles so far | yes |
| WithIterationsDo(f=i->@info("iterations: $i")) | Call f(i), where i is total number of iterations | yes |
| WithLossDo(f=x->@info("loss: $x")) | Call f(loss) where loss is the current loss | yes |
| WithTrainingLossesDo(f=v->@info(v)) | Call f(v) where v is the current batch of training losses | yes |
| WithEvaluationDo(f=e->@info("evaluation: $e")) | Call f(e) where e is the current performance evaluation object | yes |
| WithFittedParamsDo(f=fp->@info("fitted_params: $fp")) | Call f(fp) where fp is fitted parameters of training machine | yes |
| WithReportDo(f=r->@info("report: $r")) | Call f(r) where r is the training machine report | yes |
| WithModelDo(f=m->@info("model: $m")) | Call f(m) where m is the model, which may be mutated by f | yes |
| WithMachineDo(f=mach->@info("report: $mach")) | Call f(mach) where mach is the training machine in its current state | yes |
| Save(filename="machine.jls") | Save current training machine to machine1.jls, machine2.jls, etc | yes |

Table 1. Atomic controls. Some advanced options are omitted.

† For more on these controls see Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.

Stopping option. The following controls trigger a stop if the provided function f returns true and stop_if_true=true is specified in the constructor: Callback, WithNumberDo, WithIterationsDo, WithLossDo, WithTrainingLossesDo, WithEvaluationDo, WithFittedParamsDo, WithReportDo and WithModelDo.

There are also four control wrappers to modify a control's behavior:

| wrapper | description |
|---|---|
| IterationControl.skip(control, predicate=1) | Apply control every predicate applications of the control wrapper (can also be a function; see doc-string) |
| IterationControl.louder(control, by=1) | Increase the verbosity level of control by the specified value (negative values lower verbosity) |
| IterationControl.with_state_do(control; f=...) | Apply control and call f(x) where x is the internal state of control; useful for debugging. Default f logs state to Info. Warning: internal control state is not yet part of the public API. |
| IterationControl.composite(controls...) | Apply each control in controls in sequence; used internally by IterationControl.jl |

Table 2. Wrapped controls

Using training losses, and controlling model tuning

Some iterative models report a training loss as a byproduct of a fit! call, and these training losses can be used in two ways:

  1. To supplement an out-of-sample estimate of the loss in deciding when to stop, as in the PQ stopping criterion (see Prechelt, Lutz (1998)); or

  2. As a (generally less reliable) substitute for an out-of-sample loss, when wishing to train exclusively on all supplied data.

To have IteratedModel bind all data to the training machine and use training losses in place of an out-of-sample loss, specify resampling=nothing. To check if MyFavoriteIterativeModel reports training losses, load the model code and inspect supports_training_losses(MyFavoriteIterativeModel) (or run info("MyFavoriteIterativeModel")).

Controlling model tuning

An example of scenario 2 occurs when controlling hyperparameter optimization (model tuning). Recall that MLJ's TunedModel wrapper is implemented as an iterative model. Moreover, this wrapper reports, as a training loss, the lowest value of the optimization objective function so far (typically the lowest value of an out-of-sample loss, or -1 times an out-of-sample score). One may want to simply end the hyperparameter search when this value meets the NumberSinceBest stopping criterion discussed below, say, rather than introducing an extra layer of resampling to first "learn" the optimal value of the iteration parameter.

In the following example, we conduct a RandomSearch for the optimal value of the regularization parameter lambda in a RidgeRegressor using 6-fold cross-validation. By wrapping our "self-tuning" version of the regressor as an IteratedModel, with resampling=nothing and NumberSinceBest(20) in the controls, we terminate the search when the number of lambda values tested since the previous best cross-validation loss reaches 20.

using MLJ

X, y = @load_boston;
RidgeRegressor = @load RidgeRegressor pkg=MLJLinearModels verbosity=0
model = RidgeRegressor()
r = range(model, :lambda, lower=-1, upper=2, scale=x->10^x)
self_tuning_model = TunedModel(model=model,
                               tuning=RandomSearch(rng=123),
                               resampling=CV(nfolds=6),
                               range=r,
                               measure=mae);
iterated_model = IteratedModel(model=self_tuning_model,
                               resampling=nothing,
                               control=[Step(1), NumberSinceBest(20), NumberLimit(1000)])
mach = machine(iterated_model, X, y)
julia> fit!(mach)
[ Info: Training machine(DeterministicIteratedModel(model = DeterministicTunedModel(model = RidgeRegressor(lambda = 1.0, …), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(n)`. 
[ Info: final loss: 3.8933310676958963
[ Info: final training loss: 3.8933310676958963
[ Info: Stop triggered by NumberSinceBest(20) stopping criterion. 
[ Info: Total of 60 iterations. 
trained Machine; does not cache data
  model: DeterministicIteratedModel(model = DeterministicTunedModel(model = RidgeRegressor(lambda = 1.0, …), …), …)
  args:
    1:	Source @881 ⏎ Table{AbstractVector{Continuous}}
    2:	Source @090 ⏎ AbstractVector{Continuous}
julia> report(mach).model_report.best_model
RidgeRegressor(
  lambda = 0.3809484010529411,
  fit_intercept = true,
  penalize_intercept = false,
  scale_penalty_with_samples = true,
  solver = nothing)

We can use mach here to directly obtain predictions using the optimal model (trained on all data), as in

julia> predict(mach, selectrows(X, 1:4))
4-element Vector{Float64}:
 31.305322173987363
 25.209618923394196
 29.92196911451383
 29.232826153862668

Custom controls

Under the hood, control in MLJIteration is implemented using IterationControl.jl. Rather than iterating a training machine directly, we iterate a wrapped version of this object, which includes other information that a control may want to access, such as the MLJ evaluation object. This information is summarized under The training machine wrapper below.

Controls must implement two update! methods, one for initializing the control's state on the first application of the control (this state being external to the control struct) and one for all subsequent control applications, which generally updates the state as well. There are two optional methods: done, for specifying conditions triggering a stop, and takedown for specifying actions to perform at the end of controlled training.

We summarize the training algorithm, as it relates to controls, after giving a simple example.

Example 1 - Non-uniform iteration steps

Below we define a control, IterateFromList(list), to train, on each application of the control, until the iteration count reaches the next value in a user-specified list, triggering a stop when the list is exhausted. For example, to train on iteration counts on a log scale, one might use IterateFromList([round(Int, 10^x) for x in range(1, 2, length=10)]).

In the code, wrapper is an object that wraps the training machine (see The training machine wrapper below). The variable n is a counter for control cycles (unused in this example).

import IterationControl # or MLJ.IterationControl
import MLJIteration     # needed for MLJIteration.train!

struct IterateFromList
    list::Vector{Int} # sorted, unique values of the iteration parameter
    IterateFromList(v) = new(unique(sort(v)))
end

# first application of the control: train up to list[1] and initialize the state
function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n)
    Δi = control.list[1]
    verbosity > 1 && @info "Training $Δi more iterations. "
    MLJIteration.train!(wrapper, Δi) # trains the training machine
    return (index = 2, )
end

# subsequent applications: train up to the next value in the list
function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n, state)
    index = state.index
    Δi = control.list[index] - wrapper.n_iterations
    verbosity > 1 && @info "Training $Δi more iterations. "
    MLJIteration.train!(wrapper, Δi)
    return (index = index + 1, )
end

The first update method will be called the first time the control is applied, returning an initialized state = (index = 2,), which is passed to the second update method, which is called on subsequent control applications (and which returns the updated state).

A done method articulates the criterion for stopping:

IterationControl.done(control::IterateFromList, state) =
    state.index > length(control.list)

For the sake of illustration, we'll implement a takedown method; its return value is included in the IteratedModel report:

function IterationControl.takedown(control::IterateFromList, verbosity, state)
    verbosity > 1 && @info "Stepped through these values of the "*
                              "iteration parameter: $(control.list)"
    return (iteration_values=control.list, )
end

The training machine wrapper

A training machine wrapper has these properties:

  • wrapper.machine - the training machine, type Machine

  • wrapper.model - the mutable atomic model, coinciding with wrapper.machine.model

  • wrapper.n_cycles - the number of IterationControl.train!(wrapper, _) calls so far; generally the current control cycle count

  • wrapper.n_iterations - the total number of iterations applied to the model so far

  • wrapper.Δiterations - the number of iterations applied in the last IterationControl.train!(wrapper, _) call

  • wrapper.loss - the out-of-sample loss (based on the first measure in measures)

  • wrapper.training_losses - the last batch of training losses (if reported by model), an abstract vector of length wrapper.Δiterations.

  • wrapper.evaluation - the complete MLJ performance evaluation object, which has the following properties: measure, measurement, per_fold, per_observation, fitted_params_per_fold, report_per_fold (here there is only one fold). For further details, see Evaluating Model Performance.

The training algorithm

Here now is a simplified description of the training of an IteratedModel. First, the atomic model is bound in a machine - the training machine above - to a subset of the supplied data, and then wrapped in an object called wrapper below. Training the training machine for i more iterations, and updating the other data in the wrapper, requires the call MLJIteration.train!(wrapper, i). Only controls can make this call (e.g., Step(...), or IterateFromList(...) above). If we assume for simplicity there is only a single control, called control, then training proceeds as follows:

n = 1 # initialize control cycle counter
state = update!(control, wrapper, verbosity, n)
finished = done(control, state)

# subsequent training events:
while !finished
    n += 1
    state = update!(control, wrapper, verbosity, n, state)
    finished = done(control, state)
end

# finalization:
return takedown(control, verbosity, state)
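To see the protocol end to end without installing any package, the following sketch runs the loop above with the IterateFromList logic. It deliberately does not use IterationControl or MLJIteration: MockWrapper, run_control and the plain functions update!, done, takedown and train! are hypothetical stand-ins defined here, so the sketch illustrates the control protocol rather than the packages' actual dispatch.

```julia
# Hypothetical stand-in for the training machine wrapper: it only counts
# iterations, which is all IterateFromList needs.
mutable struct MockWrapper
    n_iterations::Int
end

train!(wrapper::MockWrapper, Δi) = (wrapper.n_iterations += Δi; wrapper)

struct IterateFromList
    list::Vector{Int}
    IterateFromList(v) = new(unique(sort(v)))
end

# first application: train up to list[1] and initialize the state
function update!(control::IterateFromList, wrapper, verbosity, n)
    train!(wrapper, control.list[1])
    return (index = 2,)
end

# subsequent applications: train up to the next list value, update the state
function update!(control::IterateFromList, wrapper, verbosity, n, state)
    Δi = control.list[state.index] - wrapper.n_iterations
    train!(wrapper, Δi)
    return (index = state.index + 1,)
end

done(control::IterateFromList, state) = state.index > length(control.list)

takedown(control::IterateFromList, verbosity, state) =
    (iteration_values = control.list,)

# The simplified training loop for a single control
function run_control(control, wrapper; verbosity = 0)
    n = 1
    state = update!(control, wrapper, verbosity, n)
    finished = done(control, state)
    while !finished
        n += 1
        state = update!(control, wrapper, verbosity, n, state)
        finished = done(control, state)
    end
    return takedown(control, verbosity, state)
end

wrapper = MockWrapper(0)
control = IterateFromList([round(Int, 10^x) for x in range(1, 2, length=10)])
report = run_control(control, wrapper)
wrapper.n_iterations   # 100, the last value in the list
```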

Example 2 - Cyclic learning rates

The control below implements a triangular cyclic learning rate policy, close to that described in L. N. Smith (2017): "Cyclical Learning Rates for Training Neural Networks," 2017 IEEE Winter Conference on Applications of Computer Vision (WACV), Santa Rosa, CA, USA, pp. 464-472. [In that paper learning rates are mutated (slowly) during each training iteration (epoch) while here mutations can occur once per iteration of the model, at most.]

For the sake of illustration, we suppose the iterative model, model, specified in the IteratedModel constructor, has a field called :learning_parameter, and that mutating this parameter does not trigger cold-restarts.

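import IterationControl

struct CycleLearningRate{F<:AbstractFloat}
    stepsize::Int
    lower::F
    upper::F
end

# return one cycle of learning rate values:
function one_cycle(c::CycleLearningRate)
    rise = range(c.lower, c.upper, length=c.stepsize + 1)
    fall = reverse(rise)
    return vcat(rise[1:end - 1], fall[1:end - 1])
end

# One method covers both the first application (state takes its default value)
# and subsequent ones. This control only sets the learning rate and does not
# train, so it is meant to be combined with a training control such as Step(1).
function IterationControl.update!(control::CycleLearningRate,
                                  wrapper,
                                  verbosity,
                                  n,
                                  state = (learning_rates=nothing, ))
    rates = state.learning_rates === nothing ? one_cycle(control) : state.learning_rates
    index = mod(n - 1, length(rates)) + 1 # control cycles are counted from n = 1
    r = rates[index]
    verbosity > 1 && @info "learning rate: $r"
    wrapper.model.learning_parameter = r
    return (learning_rates = rates,)
end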
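To see what schedule one_cycle produces, one can evaluate it directly. The sketch below repeats the struct and helper so that it runs without any package, and adds a hypothetical helper, rate_at, that picks the rate for control cycle n, assuming cycles are numbered from 1 and the schedule starts at lower.

```julia
# Standalone copy of the struct and helper (no IterationControl needed)
struct CycleLearningRate{F<:AbstractFloat}
    stepsize::Int
    lower::F
    upper::F
end

function one_cycle(c::CycleLearningRate)
    rise = range(c.lower, c.upper, length=c.stepsize + 1)
    fall = reverse(rise)
    return vcat(rise[1:end - 1], fall[1:end - 1])
end

# Hypothetical helper: learning rate for control cycle n = 1, 2, 3, ...
function rate_at(c::CycleLearningRate, n::Integer)
    rates = one_cycle(c)
    return rates[mod(n - 1, length(rates)) + 1]
end

c = CycleLearningRate(2, 0.1, 0.3)
one_cycle(c)                    # [0.1, 0.2, 0.3, 0.2], up to floating-point rounding
[rate_at(c, n) for n in 1:6]    # rises to upper, falls back, then repeats
```

Each cycle contains 2 * stepsize rates, so the schedule returns to lower every 2 * stepsize control cycles.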

API Reference

MLJIteration.IteratedModel (Function)
IteratedModel(model;
    controls=MLJIteration.DEFAULT_CONTROLS,
    resampling=Holdout(),
    measure=nothing,
    retrain=false,
    advanced_options...,
)

Wrap the specified supervised model in the specified iteration controls. Here model should support iteration, which is true if iteration_parameter(model) is different from nothing.

Available controls: Step(), Info(), Warn(), Error(), Callback(), WithLossDo(), WithTrainingLossesDo(), WithNumberDo(), Data(), Disjunction(), GL(), InvalidValue(), Never(), NotANumber(), NumberLimit(), NumberSinceBest(), PQ(), Patience(), Threshold(), TimeLimit(), Warmup(), WithIterationsDo(), WithEvaluationDo(), WithFittedParamsDo(), WithReportDo(), WithMachineDo(), WithModelDo(), CycleLearningRate() and Save().

Important

To make out-of-sample losses available to the controls, the wrapped model is only trained on part of the data, as iteration proceeds. The user may want to force retraining on all data after controlled iteration has finished by specifying retrain=true. See also "Training", and the retrain option, under "Extended help" below.

Extended help

Options

  • controls=Any[Step(1), Patience(5), GL(2.0), TimeLimit(Dates.Millisecond(108000)), InvalidValue()]: Controls are summarized at https://JuliaAI.github.io/MLJ.jl/dev/getting_started/ but query individual doc-strings for details and advanced options. For creating your own controls, refer to the documentation just cited.

  • resampling=Holdout(fraction_train=0.7): The default resampling holds back 30% of data for computing an out-of-sample estimate of performance (the "loss") for loss-based controls such as WithLossDo. Specify resampling=nothing if all data is to be used for controlled iteration, with each out-of-sample loss replaced by the most recent training loss, assuming this is made available by the model (supports_training_losses(model) == true). If the model does not report a training loss, you can use resampling=InSample() instead. Otherwise, resampling must have type Holdout or be a vector with one element of the form (train_indices, test_indices).

  • measure=nothing: StatisticalMeasures.jl compatible measure for estimating model performance (the "loss", but the orientation is immaterial - i.e., this could be a score). Inferred by default. Ignored if resampling=nothing.

  • retrain=false: If retrain=true or resampling=nothing, iterated_model behaves exactly like the original model but with the iteration parameter automatically selected ("learned"). That is, the model is retrained on all available data, using the same number of iterations, once controlled iteration has stopped. This is typically desired if wrapping the iterated model further, or when inserting in a pipeline or other composite model. If retrain=false (default) and resampling isa Holdout, then iterated_model behaves like the original model trained on a subset of the provided data.

  • weights=nothing: per-observation weights to be passed to measure where supported; if unspecified, these are understood to be uniform.

  • class_weights=nothing: class-weights to be passed to measure where supported; if unspecified, these are understood to be uniform.

  • operation=nothing: Operation, such as predict or predict_mode, for computing target values, or proxy target values, for consumption by measure; automatically inferred by default.

  • check_measure=true: Specify false to override checks on measure for compatibility with the training data.

  • iteration_parameter=nothing: A symbol, such as :epochs, naming the iteration parameter of model; inferred by default. Note that the actual value of the iteration parameter in the supplied model is ignored; only the value of an internal clone is mutated during training of the wrapped model.

  • cache=true: Whether or not model-specific representations of data are cached in between iteration parameter increments; specify cache=false to prioritize memory over speed.

Training

Training an instance iterated_model of IteratedModel on some data (by binding to a machine and calling fit!, for example) performs the following actions:

  • Assuming resampling !== nothing, the data is split into train and test sets, according to the specified resampling strategy.

  • A clone of the wrapped model, model, is bound to the train data in an internal machine, train_mach. If resampling === nothing, all data is used instead. This machine is the object to which controls are applied. For example, Callback(fitted_params |> print) will print the value of fitted_params(train_mach).

  • The iteration parameter of the clone is set to 0.

  • The specified controls are repeatedly applied to train_mach in sequence, until one of the controls triggers a stop. Loss-based controls (eg, Patience(), GL(), Threshold(0.001)) use an out-of-sample loss, obtained by applying measure to predictions and the test target values. (Specifically, these predictions are those returned by operation(train_mach).) If resampling === nothing then the most recent training loss is used instead. Some controls require both out-of-sample and training losses (eg, PQ()).

  • Once a stop has been triggered, a clone of model is bound to all data in a machine called mach_production below, unless retrain == false (the default) or resampling === nothing, in which case mach_production coincides with train_mach.

Prediction

Calling predict(mach, Xnew), where mach is a machine bound to iterated_model, returns predict(mach_production, Xnew). Similar statements hold for predict_mean, predict_mode and predict_median.

Controls that mutate parameters

A control is permitted to mutate the fields (hyper-parameters) of train_mach.model (the clone of model). For example, to mutate a learning rate one might use the control

Callback(mach -> mach.model.eta = 1.05*mach.model.eta)

However, unless model supports warm restarts with respect to changes in that parameter, this will trigger retraining of train_mach from scratch, with a different training outcome, which is not recommended.

Warm restarts

In the following example, the second fit! call will not restart training of the internal train_mach, assuming model supports warm restarts:

iterated_model = IteratedModel(
    model,
    controls = [Step(1), NumberLimit(100)],
)
mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)]
fit!(mach) # train for an *extra* 50 iterations

More generally, if iterated_model is mutated and fit!(mach) is called again, then a warm restart is attempted if the only parameters to change are model or controls or both.

Specifically, train_mach.model is mutated to match the current value of iterated_model.model and the iteration parameter of the latter is updated to the last value used in the preceding fit!(mach) call. Then repeated application of the (updated) controls begins anew.


Controls

IterationControl.Step (Type)
Step(; n=1)

An iteration control, as in, Step(2).

Train for n more iterations. Will never trigger a stop.

EarlyStopping.TimeLimit (Type)
TimeLimit(; t=0.5)

An early stopping criterion for loss-reporting iterative algorithms.

Stopping is triggered after t hours have elapsed since the stopping criterion was initiated.

Any Julia built-in Real type can be used for t. Subtypes of Period may also be used, as in TimeLimit(t=Minute(30)).

Internally, t is rounded to the nearest millisecond.
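As a check on units, the following sketch converts both kinds of argument to hours using the Dates standard library. The helper to_hours is hypothetical and is not part of EarlyStopping.jl; it assumes any Period argument is a whole number of milliseconds (so finer periods such as Microsecond(1500) would raise an error here rather than being rounded).

```julia
using Dates

# Hypothetical helper (not EarlyStopping.jl code): express a time limit in hours.
const MS_PER_HOUR = 3_600_000

# Real argument: interpreted as hours, rounded to the nearest millisecond
to_hours(t::Real) = round(Int, t * MS_PER_HOUR) / MS_PER_HOUR

# Period argument: converted exactly to milliseconds, then to hours
to_hours(t::Period) = Dates.value(convert(Millisecond, t)) / MS_PER_HOUR

to_hours(0.5)                  # 0.5
to_hours(Minute(30))           # 0.5, the same limit as t=0.5
to_hours(Millisecond(108000))  # about 0.03 hours, i.e. 108 seconds
```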

EarlyStopping.NumberLimit (Type)
NumberLimit(; n=100)

An early stopping criterion for loss-reporting iterative algorithms.

A stop is triggered by n consecutive loss updates, excluding "training" loss updates.

If wrapped in a stopper::EarlyStopper, this is the number of calls to done!(stopper).

EarlyStopping.NumberSinceBest (Type)
NumberSinceBest(; n=6)

An early stopping criterion for loss-reporting iterative algorithms.

A stop is triggered when the number of calls to the control, since the lowest value of the loss so far, is n.
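The rule can be sketched on a fixed sequence of losses. This is a plain-Julia illustration, not EarlyStopping.jl's code; the function name number_since_best_stop is hypothetical, and the sketch assumes a repeat of the best value does not count as a new best.

```julia
# Return the (1-based) position at which a NumberSinceBest(n)-style rule would
# stop, given a sequence of losses, or `nothing` if it never stops.
function number_since_best_stop(losses, n)
    best = Inf
    since_best = 0
    for (t, loss) in enumerate(losses)
        if loss < best          # new lowest loss: reset the counter
            best = loss
            since_best = 0
        else
            since_best += 1
        end
        since_best == n && return t
    end
    return nothing
end

losses = [0.9, 0.7, 0.8, 0.6, 0.65, 0.7, 0.62]
number_since_best_stop(losses, 3)   # 7: the best loss (0.6) was 3 updates earlier
number_since_best_stop(losses, 4)   # nothing
```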

For a customizable loss-based stopping criterion, use WithLossDo or WithTrainingLossesDo with the stop_if_true=true option.

EarlyStopping.InvalidValue (Type)
InvalidValue()

An early stopping criterion for loss-reporting iterative algorithms.

Stop if a loss (or training loss) is NaN, Inf or -Inf (or, more precisely, if isnan(loss) or isinf(loss) is true).

For a customizable loss-based stopping criterion, use WithLossDo or WithTrainingLossesDo with the stop_if_true=true option.

EarlyStopping.Threshold (Type)
Threshold(; value=0.0)

An early stopping criterion for loss-reporting iterative algorithms.

A stop is triggered as soon as the loss drops below value.

For a customizable loss-based stopping criterion, use WithLossDo or WithTrainingLossesDo with the stop_if_true=true option.

EarlyStopping.GL (Type)
GL(; alpha=2.0)

An early stopping criterion for loss-reporting iterative algorithms.

A stop is triggered when the (rescaled) generalization loss exceeds the threshold alpha.

Terminology. Suppose $E_1, E_2, ..., E_t$ are a sequence of losses, for example, out-of-sample estimates of the loss associated with some iterative machine learning algorithm. Then the generalization loss at time t is given by

$GL_t = 100 \frac{E_t - E_{opt}}{|E_{opt}|}$

where $E_{opt}$ is the minimum value of the sequence.
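The definition can be computed directly. In this sketch (hypothetical function name, not the package's code), $E_{opt}$ is taken as the minimum of $E_1, ..., E_t$ at each time t, and the first time GL exceeds the default alpha = 2.0 is located.

```julia
# GL_t = 100 * (E_t - E_opt) / |E_opt|, where E_opt = min(E_1, ..., E_t)
function generalization_losses(E)
    E_opt = Inf
    GL = Float64[]
    for E_t in E
        E_opt = min(E_opt, E_t)
        push!(GL, 100 * (E_t - E_opt) / abs(E_opt))
    end
    return GL
end

E = [0.50, 0.40, 0.42, 0.44]
generalization_losses(E)                       # approximately [0.0, 0.0, 5.0, 10.0]
findfirst(>(2.0), generalization_losses(E))    # 3: GL first exceeds alpha = 2.0 at t = 3
```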

Reference: Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.

EarlyStopping.PQ (Type)
PQ(; alpha=0.75, k=5, tol=eps(Float64))

A stopping criterion for training iterative supervised learners.

A stop is triggered when Prechelt's progress-modified generalization loss exceeds the threshold alpha ($PQ_t > \alpha$), or if the training progress drops to tol or below ($P_t \le \mathrm{tol}$). Here k is the number of training (in-sample) losses used to estimate the training progress.

Context and explanation of terminology

The training progress at time $t$ is defined by

$P_t = 1000 \frac{|M - m|}{|m|}$

where $M$ is the mean of the last k training losses $F_1, F_2, …, F_k$ and $m$ is the minimum value of those losses.

The progress-modified generalization loss at time $t$ is then given by

$PQ_t = GL_t / P_t$

where $GL_t$ is the generalization loss at time $t$; see GL.

PQ will stop when the following are true:

  1. At least k training samples have been collected via done!(c::PQ, loss; training = true) or update_training(c::PQ, loss, state)
  2. The last update was an out-of-sample update. (done!(::PQ, loss; training=true) is always false)
  3. The progress-modified generalization loss exceeds the threshold ($PQ_t > \alpha$) OR the training progress stalls ($P_t \le \mathrm{tol}$).
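The quantities above can be sketched in a few lines. The functions training_progress and pq_stop are hypothetical illustrations of the formulas and of condition 3, not EarlyStopping.jl's implementation; they ignore the bookkeeping in conditions 1 and 2.

```julia
# Training progress over the last k training losses: P_t = 1000 |M - m| / |m|
function training_progress(training_losses, k)
    recent = training_losses[end-k+1:end]
    M = sum(recent) / k     # mean of the last k training losses
    m = minimum(recent)     # minimum of the last k training losses
    return 1000 * abs(M - m) / abs(m)
end

# Condition 3: stop if progress has stalled (P_t <= tol) or PQ_t = GL_t / P_t > alpha.
# The `||` short-circuits, so there is no division when P_t is zero.
pq_stop(GL_t, P_t; alpha = 0.75, tol = eps(Float64)) = P_t <= tol || GL_t / P_t > alpha

P = training_progress([5.0, 4.0, 3.0, 2.0, 1.0], 5)   # 2000.0
pq_stop(5.0, P)                                      # PQ_t = 0.0025: no stop
pq_stop(2000.0, P)                                   # PQ_t = 1.0 > 0.75: stop
pq_stop(0.0, training_progress(fill(0.5, 5), 5))     # progress has stalled: stop
```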

Reference: Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.

IterationControl.Info (Type)
Info(f=identity)

An iteration control, as in, Info(my_loss_function).

Log to Info the value of f(m), where m is the object being iterated. If IterationControl.expose(m) has been overloaded, then log f(expose(m)) instead.

Can be suppressed by setting the global verbosity level sufficiently low.

See also Warn, Error.

IterationControl.Warn (Type)
Warn(predicate; f="")

An iteration control, as in, Warn(m -> length(m.cache) > 100, f="Memory low").

If predicate(m) is true, then log to Warn the value of f (or f(IterationControl.expose(m)) if f is a function). Here m is the object being iterated.

Can be suppressed by setting the global verbosity level sufficiently low.

See also Info, Error.

IterationControl.Error (Type)
Error(predicate; f="", exception=nothing)

An iteration control, as in, Error(m -> isnan(m.bias), f="Bias overflow!").

If predicate(m) is true, then log at the Error level the value of f (or f(IterationControl.expose(m)) if f is a function) and stop iteration at the end of the current control cycle. Here m is the object being iterated.

Specify exception=... to throw an immediate exception, without waiting until the end of the control cycle.

See also Info, Warn.

IterationControl.Callback (Type)
Callback(f=_->nothing, stop_if_true=false, stop_message=nothing, raw=false)

An iteration control, as in, Callback(m->put!(v, my_loss_function(m))).

Call f(IterationControl.expose(m)), where m is the object being iterated, unless raw=true, in which case call f(m) (guaranteed if expose has not been overloaded.) If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

IterationControl.WithNumberDo (Type)
WithNumberDo(f=n->@info("number: $n"), stop_if_true=false, stop_message=nothing)

An iteration control, as in, WithNumberDo(n->put!(my_channel, n)).

Call f(n + 1), where n is the number of complete control cycles of the control (so f is called with 1, 2, 3, ..., unless the control is wrapped in IterationControl.skip).

If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

MLJIteration.WithIterationsDo (Type)
WithIterationsDo(f=x->@info("iterations: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in, WithIterationsDo(x->put!(my_channel, x)).

Call f(x), where x is the current number of model iterations (generally more than the number of control cycles). If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

IterationControl.WithLossDo (Type)
WithLossDo(f=x->@info("loss: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in, WithLossDo(x->put!(my_losses, x)).

Call f(loss), where loss is the current loss.

If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

IterationControl.WithTrainingLossesDo (Type)
WithTrainingLossesDo(f=v->@info("training: $v"), stop_if_true=false, stop_message=nothing)

An iteration control, as in, WithTrainingLossesDo(v->put!(my_losses, last(v)).

Call f(training_losses), where training_losses is the vector of most recent batch of training losses.

If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.
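A minimal sketch, assuming IterationControl.jl is installed. SquareRooter is a hypothetical toy model that stores one training loss per iteration and exposes the latest batch through the IterationControl.training_losses interface method, so with Step(3) each batch should hold three losses.

```julia
using IterationControl

# Hypothetical toy model (illustration only): Newton's method for √x.
mutable struct SquareRooter
    x::Float64
    root::Float64
    training_losses::Vector{Float64}
end
SquareRooter(x) = SquareRooter(x, 1.0, Float64[])

# Each call starts a fresh batch of per-iteration training losses.
function IterationControl.train!(m::SquareRooter, n::Int)
    m.training_losses = Float64[]
    for _ in 1:n
        m.root = (m.root + m.x / m.root) / 2
        push!(m.training_losses, abs(m.root^2 - m.x))
    end
    return m
end
IterationControl.training_losses(m::SquareRooter) = m.training_losses

model = SquareRooter(9.0)
batches = Vector{Float64}[]

IterationControl.train!(model,
                        Step(3),
                        # copy, so later batches cannot alias earlier ones
                        WithTrainingLossesDo(v -> push!(batches, copy(v))),
                        NumberLimit(2))

println(batches)
```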

MLJIteration.WithEvaluationDo (Type)
WithEvaluationDo(f=x->@info("evaluation: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in WithEvaluationDo(x->put!(my_channel, x)).

Call f(x), where x is the latest performance evaluation, as returned by evaluate!(train_mach, resampling=..., ...). Not valid if resampling=nothing. If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.
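A minimal sketch that extracts the out-of-sample loss from each evaluation, assuming MLJ and EvoTrees.jl are installed. It relies on the measurement field of the PerformanceEvaluation object, which holds one aggregated value per measure.

```julia
using MLJ

X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

oos_losses = Float64[]

iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5),
                                         # e is a PerformanceEvaluation; keep the log_loss value
                                         WithEvaluationDo(e -> push!(oos_losses, e.measurement[1])),
                                         NumberLimit(3)])

mach = machine(iterated_model, X, y)
fit!(mach, verbosity=0)

println(oos_losses)
```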

MLJIteration.WithFittedParamsDo (Type)
WithFittedParamsDo(f=x->@info("fitted_params: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in WithFittedParamsDo(x->put!(my_channel, x)).

Call f(x), where x = fitted_params(mach) are the fitted parameters of the training machine, mach, in its current state. If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

MLJIteration.WithReportDo (Type)
WithReportDo(f=x->@info("report: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in WithReportDo(x->put!(my_channel, x)).

Call f(x), where x = report(mach) is the report associated with the training machine, mach, in its current state. If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.

MLJIteration.WithModelDo (Type)
WithModelDo(f=x->@info("model: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in WithModelDo(x->put!(my_channel, x)).

Call f(x), where x is the model associated with the training machine; f may mutate x, as in f(x) = (x.learning_rate *= 0.9). If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.
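A minimal sketch of a learning-rate schedule implemented with WithModelDo, assuming MLJ and EvoTrees.jl are installed. Here eta is EvoTrees' learning rate and decay! is a hypothetical helper; whether EvoTrees warm-restarts or retrains from scratch after eta changes is not claimed here.

```julia
using MLJ

X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

etas = Float64[]

# Hypothetical helper: shrink eta by 10% each control cycle and record the new value.
function decay!(model)
    model.eta *= 0.9        # mutates the training machine's model
    push!(etas, model.eta)
end

iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, eta=0.1),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5), WithModelDo(decay!), NumberLimit(3)])

mach = machine(iterated_model, X, y)
fit!(mach, verbosity=0)

println(etas)
```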

MLJIteration.WithMachineDo (Type)
WithMachineDo(f=x->@info("machine: $x"), stop_if_true=false, stop_message=nothing)

An iteration control, as in WithMachineDo(x->put!(my_channel, x)).

Call f(x), where x is the training machine in its current state. If stop_if_true is true, then trigger an early stop if the value returned by f is true, logging the stop_message if specified.
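The sketch below combines WithFittedParamsDo, WithReportDo and WithMachineDo to inspect the training machine at each cycle, assuming MLJ and EvoTrees.jl are installed. It assumes, consistent with the WithModelDo description, that the training machine's model is the atomic EvoTreeClassifier, so mach.model.nrounds reads the current iteration parameter.

```julia
using MLJ

X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

fitted, reports, nrounds = Any[], Any[], Int[]

iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5),
                                         WithFittedParamsDo(x -> push!(fitted, x)),
                                         WithReportDo(x -> push!(reports, x)),
                                         # the training machine's model carries nrounds
                                         WithMachineDo(m -> push!(nrounds, m.model.nrounds)),
                                         NumberLimit(3)])

mach = machine(iterated_model, X, y)
fit!(mach, verbosity=0)

println(nrounds)
```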

MLJIteration.Save (Type)
Save(filename="machine.jls")

An iteration control, as in Save("run3/machine.jls").

Save the current state of the machine being iterated to disk, using the provided filename decorated with a number, as in "run3/machine42.jls". By default the Serialization module is used, but this can be changed by specifying method=save_fn, where save_fn(::String, ::Any) is any serialization function. For more on what is meant by "the machine being iterated", see IteratedModel.
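A minimal sketch that saves a snapshot of the training machine on each control cycle into a scratch directory, assuming MLJ and EvoTrees.jl are installed and the default serialization method is used.

```julia
using MLJ

X, y = make_moons(100, rng=123, noise=0.5)
EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

dir = mktempdir()  # scratch directory for the snapshots

iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123),
                               resampling=Holdout(),
                               measures=log_loss,
                               controls=[Step(5),
                                         Save(joinpath(dir, "machine.jls")),
                                         NumberLimit(3)])

mach = machine(iterated_model, X, y)
fit!(mach, verbosity=0)

# Each snapshot's filename is decorated with a number, as described above.
saved = filter(endswith(".jls"), readdir(dir))
println(saved)
```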


Control wrappers

IterationControl.skip (Function)
IterationControl.skip(control, predicate=1)

An iteration control wrapper.

If predicate is an integer k: apply control on every kth call to the wrapper, starting with the kth call (that is, on calls k, 2k, 3k, ...).

If predicate is a function: apply control as usual when predicate(n + 1) is true, but otherwise skip it. Here n is the number of control cycles applied so far.

IterationControl.louder (Function)
IterationControl.louder(control, by=1)

Wrap control to make it more (or less) verbose. The wrapper behaves like control, but as if the global verbosity were increased by the value by.
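A minimal sketch of wrapping a single control with louder, assuming IterationControl.jl is installed. It uses the default increment so that no assumption is needed about how by is passed. SquareRooter is a hypothetical toy model (Newton's method for a square root), not a library type.

```julia
using IterationControl

# Hypothetical toy model (illustration only): Newton's method for √x.
mutable struct SquareRooter
    x::Float64
    root::Float64
end
SquareRooter(x) = SquareRooter(x, 1.0)

# Interface method called by Step(n): perform n more iterations.
function IterationControl.train!(m::SquareRooter, n::Int)
    for _ in 1:n
        m.root = (m.root + m.x / m.root) / 2
    end
    return m
end

model = SquareRooter(9.0)

# louder(Step(1)) acts like Step(1), but as if verbosity were one level higher;
# the other controls keep the ordinary verbosity.
IterationControl.train!(model, IterationControl.louder(Step(1)), NumberLimit(3))

println(model.root)
```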

IterationControl.with_state_do (Function)
IterationControl.with_state_do(control,
                              f=x->@info "$(typeof(control)) state: $x")

Wrap control to give access to its internal state. The wrapper acts exactly like control, except that f is called on the internal state of control. If f is not specified, the control type and state are logged at the Info level at every update (useful for debugging new controls).

Warning. The internal state of a control is not yet considered part of the public interface and could change in any pre-1.0 release of IterationControl.jl.
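A minimal debugging sketch, assuming IterationControl.jl is installed. It relies only on the default f, which logs the wrapped control's state at every update. Because that state is not public, the example inspects nothing beyond the log. SquareRooter is a hypothetical toy model, not a library type.

```julia
using IterationControl

# Hypothetical toy model (illustration only): Newton's method for √x.
mutable struct SquareRooter
    x::Float64
    root::Float64
end
SquareRooter(x) = SquareRooter(x, 1.0)

# Interface method called by Step(n): perform n more iterations.
function IterationControl.train!(m::SquareRooter, n::Int)
    for _ in 1:n
        m.root = (m.root + m.x / m.root) / 2
    end
    return m
end

model = SquareRooter(9.0)

# Log NumberLimit's internal state each cycle; stopping behavior is unchanged.
IterationControl.train!(model,
                        Step(1),
                        IterationControl.with_state_do(NumberLimit(3)))

println(model.root)
```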
