Machines
Under the hood, calling fit! on a machine calls either MLJBase.fit or MLJBase.update, depending on the machine's internal state, as recorded in the additional fields previous_model and previous_rows. These lower-level fit and update methods dispatch on the model and a view of the data defined by the optional rows keyword argument of fit! (all rows by default). In this way, if a model update method is implemented, calls to fit! can avoid redundant calculations for certain kinds of model mutations (e.g., increasing the number of epochs in a neural network).
The interested reader can learn more about machine internals by examining the simplified code excerpt in Internals.
using MLJ
forest = EnsembleModel(atom=(@load DecisionTreeClassifier), n=10);
X, y = @load_iris;
mach = machine(forest, X, y)
fit!(mach, verbosity=2);
Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39
Generally, changing a hyperparameter triggers retraining on subsequent calls to fit!:
julia> forest.bagging_fraction=0.5
0.5
julia> fit!(mach, verbosity=2);
[ Info: Updating Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
[ Info: Truncating existing ensemble.
However, for this iterative model, increasing the iteration parameter only adds models to the existing ensemble:
julia> forest.n=15
15
julia> fit!(mach, verbosity=2);
[ Info: Updating Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
[ Info: Building on existing ensemble of length 10
[ Info: One hash per new atom trained:
#####
Call fit! again without making a change, and no retraining occurs:
julia> fit!(mach);
┌ Info: Not retraining Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
└ It appears up-to-date. Use `force=true` to force retraining.
However, retraining can be forced:
julia> fit!(mach, force=true);
[ Info: Training Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
Retraining is also triggered if the view of the data changes:
julia> fit!(mach, rows=1:100);
[ Info: Training Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
julia> fit!(mach, rows=1:100);
┌ Info: Not retraining Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 9…39.
└ It appears up-to-date. Use `force=true` to force retraining.
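The saving that an update method makes possible in the ensemble example can be illustrated outside MLJ with a toy iterative model. This is a plain-Julia sketch, not MLJ code: ToyEnsemble, toy_fit and toy_update are hypothetical stand-ins for a model type and its fit and update methods, and "training" a member is simulated by incrementing a counter.

```julia
# Hypothetical toy model (not part of MLJ): an ensemble whose only
# hyperparameter is the number of members n.
mutable struct ToyEnsemble
    n::Int
end

const TRAININGS = Ref(0)   # counts how many members have been trained

# Simulate training one ensemble member.
function train_member(i::Int)
    TRAININGS[] += 1
    return i
end

# Analogue of a fit method: train all n members from scratch.
toy_fit(model::ToyEnsemble) = [train_member(i) for i in 1:model.n]

# Analogue of an update method: when n has grown, train only the new
# members; otherwise fall back to a full fit.
function toy_update(model::ToyEnsemble, old_fitresult::Vector{Int})
    k = length(old_fitresult)
    model.n >= k || return toy_fit(model)
    return vcat(old_fitresult, [train_member(i) for i in (k + 1):model.n])
end

model = ToyEnsemble(10)
fitresult = toy_fit(model)                # trains 10 members
model.n = 15
fitresult = toy_update(model, fitresult)  # trains only members 11 to 15
println(length(fitresult), " members, ", TRAININGS[], " trainings")
```

A full refit after raising n to 15 would have trained 25 members in total; the update path trains 15.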
For a supervised machine, the predict method calls a lower-level MLJBase.predict method, dispatched on the underlying model and the fitresult (see below). To see predict in action, as well as its unsupervised cousins transform and inverse_transform, see Getting Started.
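The forwarding from a machine-level predict to a model-level predict can be sketched in plain Julia. Everything here is hypothetical and much simpler than MLJ: MeanRegressor is a toy model, model_fit and model_predict play the roles of MLJBase.fit and MLJBase.predict, and ToyMachine keeps only the model, the data and the fitresult.

```julia
# Hypothetical toy model (not part of MLJ): predicts the mean of the training target.
struct MeanRegressor end

# Model-level methods, analogues of MLJBase.fit and MLJBase.predict.
# Here the fitresult is simply the training mean.
model_fit(::MeanRegressor, X, y) = sum(y) / length(y)
model_predict(::MeanRegressor, fitresult, Xnew) = fill(fitresult, length(Xnew))

# A stripped-down machine holding the model, its data and the fitresult.
mutable struct ToyMachine
    model::Any
    args::Tuple
    fitresult::Any
end
toy_machine(model, args...) = ToyMachine(model, args, nothing)

# Machine-level operations forward to the model-level methods,
# dispatching on the model and passing along the stored fitresult.
function toy_fit!(mach::ToyMachine)
    mach.fitresult = model_fit(mach.model, mach.args...)
    return mach
end
toy_predict(mach::ToyMachine, Xnew) = model_predict(mach.model, mach.fitresult, Xnew)

mach = toy_machine(MeanRegressor(), [1, 2, 3, 4], [2.0, 4.0, 6.0, 8.0])
toy_fit!(mach)
yhat = toy_predict(mach, [5, 6])   # [5.0, 5.0]
```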
Here is a complete list of the fields of a machine:
- model: the struct containing the hyperparameters to be used in calls to fit!
- fitresult: the learned parameters in a raw form, initially undefined
- args: a tuple of the data (in the supervised learning example above, args = (X, y))
- report: outputs of training not encoded in fitresult (e.g., feature rankings)
- previous_model: a deep copy of the model used in the last call to fit!
- previous_rows: a copy of the row indices used in the last call to fit!
- cache
Instead of data X and y, the machine constructor can be provided with Node or Source objects ("dynamic data") to obtain a NodalMachine rather than a regular Machine object. A NodalMachine includes the same fields listed above. See Composing Models for more on this advanced feature.
Inspecting machines
There are two methods for inspecting the outcomes of training in MLJ. To obtain a named tuple describing the learned parameters, in a user-friendly way if possible, use fitted_params(mach). All other training-related outcomes are inspected with report(mach).
X, y = @load_iris
pca = @load PCA
mach = machine(pca, X)
fit!(mach)
Machine{PCA} @ 1…97
julia> fitted_params(mach)
(projection = PCA(indim = 4, outdim = 3, principalratio = 0.99481691454981),)
julia> report(mach)
(indim = 4,
outdim = 3,
mean = [5.8433333333333355, 3.054000000000001, 3.7586666666666697, 1.1986666666666674],
principalvars = [4.224840768320109, 0.24224357162751542, 0.07852390809415459],
tprincipalvar = 4.545608248041779,
tresidualvar = 0.02368302712600201,
tvar = 4.569291275167781,)
API Reference
StatsBase.fit! (Function)
fit!(mach::Machine; rows=nothing, verbosity=1, force=false)
When called for the first time, call
MLJBase.fit(mach.model, verbosity, mach.args...)
storing the returned fit-result and report in mach. Subsequent calls do nothing unless: (i) force=true, or (ii) the specified rows are different from those used the last time a fit-result was computed, or (iii) mach.model has changed since the last time a fit-result was computed (the machine is stale). In cases (i) or (ii), MLJBase.fit is called again. Otherwise, MLJBase.update is called.
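The decision rule above can be written as a small function. This is a hypothetical sketch, not MLJ's source: fit_action returns :fit, :update or :skip, and models are represented as named tuples so that == compares hyperparameter values.

```julia
# Hypothetical sketch of the decision made by fit! on a regular Machine.
# Models are named tuples here, so `==` compares hyperparameter values.
function fit_action(; first_call::Bool, force::Bool=false,
                    rows=nothing, previous_rows=nothing,
                    model, previous_model)
    if first_call || force || rows != previous_rows
        return :fit        # first call, case (i) or case (ii): train from scratch
    elseif model != previous_model
        return :update     # case (iii): the machine is stale
    else
        return :skip       # up to date: nothing to do
    end
end

m = (n = 10, bagging_fraction = 1.0)
fit_action(first_call=true, model=m, previous_model=m)               # :fit
fit_action(first_call=false, model=m, previous_model=m)              # :skip
fit_action(first_call=false, model=(n = 15, bagging_fraction = 1.0),
           previous_model=m)                                         # :update
fit_action(first_call=false, rows=1:100, model=m, previous_model=m)  # :fit
fit_action(first_call=false, force=true, model=m, previous_model=m)  # :fit
```

Whether the :update branch actually saves work depends on the model's own update method; in the ensemble example it builds on the existing ensemble.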
fit!(mach::NodalMachine; rows=nothing, verbosity=1, force=false)
When called for the first time, attempt to call
MLJBase.fit(mach.model, verbosity, mach.args...)
This will fail if an argument of the machine ultimately depends on some other untrained machine for successful calling, but this is resolved by instead calling fit! on any node N for which mach in machines(N) is true, which trains all necessary machines in an appropriate order. Subsequent fit! calls do nothing unless: (i) force=true, or (ii) some machine on which mach depends has computed a new fit-result since mach last computed its fit-result, or (iii) the specified rows have changed since the last time a fit-result was computed, or (iv) mach is stale (see below). In cases (i), (ii) or (iii), MLJBase.fit is called. Otherwise, MLJBase.update is called.
A machine mach is stale if mach.model has changed since the last time a fit-result was computed, or if one of its training arguments is stale. A node N is stale if N.machine is stale or one of its arguments is stale. Source nodes are never stale.
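Because the staleness rules refer to each other, staleness propagates downstream through a network. A plain-Julia sketch with hypothetical toy types (ToySource, ToyNodalMachine and ToyNode are not MLJ's Source, NodalMachine and Node) encodes the three rules directly as methods of one function:

```julia
# Hypothetical toy types (not MLJ's NodalMachine, Node and Source).
struct ToySource end

mutable struct ToyNodalMachine
    model_changed::Bool   # has mach.model changed since the last fit-result?
    args::Vector{Any}     # training arguments: toy nodes or sources
end

struct ToyNode
    machine::Union{ToyNodalMachine,Nothing}   # nothing for a static operation
    args::Vector{Any}
end

# The three staleness rules, one method each.
is_stale(::ToySource) = false
is_stale(mach::ToyNodalMachine) = mach.model_changed || any(is_stale, mach.args)
is_stale(N::ToyNode) =
    (N.machine !== nothing && is_stale(N.machine)) || any(is_stale, N.args)

# A two-stage chain: X -> N1 (machine m1) -> N2 (machine m2).
X = ToySource()
m1 = ToyNodalMachine(false, Any[X])
N1 = ToyNode(m1, Any[X])
m2 = ToyNodalMachine(false, Any[N1])
N2 = ToyNode(m2, Any[N1])

before = is_stale(N2)     # false
m1.model_changed = true   # mutate the upstream model
after = is_stale(N2)      # true: staleness propagates downstream
```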
Note that a nodal machine obtains its training data by calling its node arguments on the specified rows
(rather than indexing its arguments on those rows) and that this calling is a recursive operation on nodes upstream of those arguments.
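To make the distinction concrete, here is a plain-Julia toy in which calling a derived node on some rows recursively calls its upstream node on those rows. ToySrc and ToyOp are hypothetical, not MLJ types. When an upstream operation is not row-wise, such as the centering function below, calling on rows and indexing the full output give different results.

```julia
# Hypothetical toy nodes (not MLJ types).
struct ToySrc
    data::Vector{Float64}
end

struct ToyOp
    f::Function   # operation applied to the upstream output
    arg::Any      # upstream node (a ToySrc or another ToyOp)
end

(s::ToySrc)(rows) = s.data[rows]      # a source just returns the requested rows
(n::ToyOp)(rows) = n.f(n.arg(rows))   # recurse upstream on the same rows, then apply f

center(x) = x .- sum(x) / length(x)   # not row-wise: depends on all rows it receives

src = ToySrc([1.0, 2.0, 3.0, 4.0, 5.0])
centered = ToyOp(center, src)

centered(1:2)        # calling on rows:           [-0.5, 0.5]
centered(1:5)[1:2]   # indexing the full output:  [-2.0, -1.0]
```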
fit!(N::Node; rows=nothing, verbosity::Int=1, force::Bool=false)
Train all machines in the learning network terminating at node N, in an appropriate order. These machines are those returned by machines(N).
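An appropriate order is one in which each machine is trained after every machine it depends on. The sketch below computes such an order with a depth-first topological sort over a hypothetical dependency table; the machine labels are made up, and this is not necessarily how MLJ orders machines internally.

```julia
# Hypothetical sketch: order machines so that each comes after the machines
# it depends on (depth-first topological sort; assumes no cycles).
function training_order(deps::Dict{Symbol,Vector{Symbol}}, terminal::Symbol)
    order = Symbol[]
    visited = Set{Symbol}()
    function visit(m::Symbol)
        m in visited && return
        push!(visited, m)
        for d in get(deps, m, Symbol[])
            visit(d)          # train upstream machines first
        end
        push!(order, m)       # every dependency of m is already in `order`
        return
    end
    visit(terminal)
    return order
end

# Made-up network: a standardizer feeds PCA and KNN; PCA feeds a tree;
# a final "stack" machine depends on both the tree and KNN.
deps = Dict(:standardizer => Symbol[],
            :pca          => [:standardizer],
            :tree         => [:pca],
            :knn          => [:standardizer],
            :stack        => [:tree, :knn])

training_order(deps, :stack)   # [:standardizer, :pca, :tree, :knn, :stack]
```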