Machines

Under the hood, calling fit! on a machine calls either MLJBase.fit or MLJBase.update, depending on the machine's internal state, as recorded in additional fields previous_model and previous_rows. These lower-level fit and update methods dispatch on the model and a view of the data defined by the optional rows keyword argument of fit! (all rows by default). In this way, if a model update method is implemented, calls to fit! can avoid redundant calculations for certain kinds of model mutations (eg, increasing the number of epochs in a neural network).

The interested reader can learn more about machine internals by examining the simplified code excerpt in the Internals section below.

forest = EnsembleModel(atom=(@load DecisionTreeClassifier), n=10);
X, y = @load_iris;
mach = machine(forest, X, y)
fit!(mach, verbosity=2);
Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24

Generally, changing a hyperparameter triggers retraining on subsequent calls to fit!:

julia> forest.bagging_fraction=0.5
0.5

julia> fit!(mach, verbosity=2);
[ Info: Updating Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.
[ Info: Truncating existing ensemble.

However, for this iterative model, increasing the iteration parameter only adds models to the existing ensemble:

julia> forest.n=15
15

julia> fit!(mach, verbosity=2);
[ Info: Updating Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.
[ Info: Building on existing ensemble of length 10
[ Info: One hash per new atom trained: 
#####

Call fit! again without making a change and no retraining occurs:

julia> fit!(mach);
┌ Info: Not retraining Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.
└  It appears up-to-date. Use `force=true` to force retraining.

However, retraining can be forced:

julia> fit!(mach, force=true);
[ Info: Training Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.

Retraining is also triggered if the view of the data changes:

julia> fit!(mach, rows=1:100);
[ Info: Training Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.
julia> fit!(mach, rows=1:100);
┌ Info: Not retraining Machine{ProbabilisticEnsembleModel{DecisionTreeClassifier}} @ 1…24.
└  It appears up-to-date. Use `force=true` to force retraining.

Inspecting machines

There are two methods for inspecting the outcomes of training in MLJ. To obtain a named-tuple describing the learned parameters (in a user-friendly way where possible) use fitted_params(mach). All other training-related outcomes are inspected with report(mach).

X, y = @load_iris
pca = @load PCA
mach = machine(pca, X)
fit!(mach)
Machine{PCA} @ 1…80
julia> fitted_params(mach)
(projection = PCA(indim = 4, outdim = 3, principalratio = 0.99481691454981),)

julia> report(mach)
(indim = 4,
 outdim = 3,
 mean = [5.8433333333333355, 3.054000000000001, 3.7586666666666697, 1.1986666666666674],
 principalvars = [4.224840768320109, 0.24224357162751542, 0.07852390809415459],
 tprincipalvar = 4.545608248041779,
 tresidualvar = 0.02368302712600201,
 tvar = 4.569291275167781,)
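
Since fitted_params and report both return named tuples, individual outcomes can be accessed by property name, as in this continuation of the PCA example (variable names here are illustrative):

vars = report(mach).principalvars         # the principal component variances
proj = fitted_params(mach).projection     # the fitted PCA projection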

Saving machines

To save a machine to file, use the MLJ.save command:

tree = @load DecisionTreeClassifier
mach = fit!(machine(tree, X, y))
MLJ.save("my_machine.jlso", mach)

To de-serialize, one uses the machine constructor:

mach2 = machine("my_machine.jlso")
predict(mach2, Xnew);

The machine mach2 cannot be retrained; however, by providing data to the constructor, one can enable retraining using the saved model hyperparameters, in which case the saved learned parameters are overwritten:

mach3 = machine("my_machine.jlso", Xnew, ynew)
fit!(mach3)

Internals

For a supervised machine the predict method calls a lower-level MLJBase.predict method, dispatched on the underlying model and the fitresult (see below). To see predict in action, as well as its unsupervised cousins transform and inverse_transform, see Getting Started.
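
Assuming mach has already been trained, the high-level call predict(mach, Xnew) is then roughly equivalent to the lower-level call sketched here (Xnew standing for new input data):

# a sketch of what predict(mach, Xnew) does for a trained supervised machine:
yhat = MLJBase.predict(mach.model, mach.fitresult, Xnew)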

The fields of a Machine instance (which should not generally be accessed by the user) are:

  • model - the struct containing the hyperparameters to be used in calls to fit!

  • fitresult - the learned parameters in a raw form, initially undefined

  • args - a tuple of the data (in the supervised learning example above, args = (X, y))

  • report - outputs of training not encoded in fitresult (eg, feature rankings)

  • previous_model - a deep copy of the model used in the last call to fit!

  • previous_rows - a copy of the row indices used in last call to fit!

  • cache

Instead of data X and y, the machine constructor can be provided Node or Source objects ("dynamic data") to obtain a NodalMachine, rather than a regular Machine object; a NodalMachine has the fields listed above and some others. See Composing Models for more on this advanced feature.
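
Here is a minimal sketch of constructing such a machine from source nodes, assuming the forest model and data defined earlier:

Xs = source(X)
ys = source(y)
nodal_mach = machine(forest, Xs, ys)    # a NodalMachine, not a Machine
fit!(nodal_mach)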

API Reference

StatsBase.fit! — Function
fit!(mach::Machine; rows=nothing, verbosity=1, force=false)

When called for the first time, this calls

fit(mach.model, verbosity, mach.args...)

storing the returned fit-result and report in mach. Subsequent calls do nothing unless: (i) force=true, or (ii) the specified rows are different from those used the last time a fit-result was computed, or (iii) mach.model has changed since the last time a fit-result was computed (the machine is stale). In cases (i) or (ii) MLJBase.fit is called again. Otherwise, MLJBase.update is called.
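
The decision logic can be sketched as follows (a simplification, not the actual implementation; restriction of the data to the specified rows is omitted for brevity):

# simplified sketch of the retraining logic described above
if !isdefined(mach, :fitresult) || force || rows != mach.previous_rows
    # first call, case (i), or case (ii): train from scratch
    mach.fitresult, mach.cache, mach.report =
        MLJBase.fit(mach.model, verbosity, mach.args...)
elseif mach.model != mach.previous_model
    # case (iii): the model has changed; attempt an efficient update
    mach.fitresult, mach.cache, mach.report =
        MLJBase.update(mach.model, verbosity, mach.fitresult,
                       mach.cache, mach.args...)
end
# otherwise the machine is up-to-date and nothing is done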

fit!(mach::NodalMachine; rows=nothing, verbosity=1, force=false)

When called for the first time, this attempts to call

fit(mach.model, verbosity, mach.args...)

This will fail if an argument of the machine depends ultimately on some other untrained machine for successful calling, but this is resolved by instead calling fit! on any node N for which mach in machines(N) is true, which trains all necessary machines in an appropriate order. Subsequent fit! calls do nothing unless: (i) force=true, or (ii) some machine on which mach depends has computed a new fit-result since mach last computed its fit-result, or (iii) the specified rows have changed since the last time a fit-result was computed, or (iv) mach is stale (see below). In cases (i), (ii) or (iii), MLJBase.fit is called. Otherwise MLJBase.update is called.

A machine mach is stale if mach.model has changed since the last time a fit-result was computed, or if one of its training arguments is stale. A node N is stale if N.machine is stale or one of its arguments is stale. Source nodes are never stale.

Note that a nodal machine obtains its training data by calling its node arguments on the specified rows (rather than indexing its arguments on those rows) and that this calling is a recursive operation on nodes upstream of those arguments.

fit!(N::Node; rows=nothing, verbosity::Int=1, force::Bool=false)

Train all machines in the learning network terminating at node N, in an appropriate order. These machines are those returned by machines(N).
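
For illustration, here is a sketch of a small learning network, assuming the tree model loaded earlier and MLJ's built-in Standardizer; calling fit! on the terminal node trains both machines in order:

Xs = source(X)
ys = source(y)
stand_mach = machine(Standardizer(), Xs)
W = transform(stand_mach, Xs)        # a Node
tree_mach = machine(tree, W, ys)
yhat = predict(tree_mach, W)         # the terminal Node
fit!(yhat)                           # trains stand_mach, then tree_mach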

MLJModelInterface.save — Function
MLJ.save(filename, mach::AbstractMachine; kwargs...)
MLJ.save(io, mach::Machine; kwargs...)

MLJBase.save(filename, mach::AbstractMachine; kwargs...)
MLJBase.save(io, mach::Machine; kwargs...)

Serialize the machine mach to a file with path filename, or to an input/output stream io (at least IOBuffer instances are supported).

The format is JLSO (a wrapper for Julia native or BSON serialization) unless a custom format has been implemented for the model type of mach.model. The keyword arguments kwargs are passed to the format-specific serializer, which in the JLSO case include these:

keyword       values                     default
format        :julia_serialize, :BSON    :julia_serialize
compression   :gzip, :none               :none

See https://github.com/invenia/JLSO.jl for details.
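
For example, to serialize using BSON with gzip compression, one might write (keyword values as listed in the table above):

MLJ.save("my_machine.jlso", mach, format=:BSON, compression=:gzip)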

Machines are de-serialized using the machine constructor as shown in the example below. Data (or nodes) may be optionally passed to the constructor for retraining on new data using the saved model.

Example

using MLJ
tree = @load DecisionTreeClassifier
X, y = @load_iris
mach = fit!(machine(tree, X, y))

MLJ.save("tree.jlso", mach, compression=:none)
mach_predict_only = machine("tree.jlso")
predict(mach_predict_only, X)

mach2 = machine("tree.jlso", selectrows(X, 1:100), y[1:100])
predict(mach2, X) # same as above

fit!(mach2) # saved learned parameters are over-written
predict(mach2, X) # not same as above

# using a buffer:
io = IOBuffer()
MLJ.save(io, mach)
seekstart(io)
predict_only_mach = machine(io)
predict(predict_only_mach, X)
Only load files from trusted sources

Maliciously constructed JLSO files, like pickles and most other general-purpose serialization formats, can allow arbitrary code execution during loading. This means it is possible for someone to use a JLSO file that looks like a serialized MLJ machine as a Trojan horse.