Simple User Defined Models
To quickly implement a new supervised model in MLJ, it suffices to:
- Define a mutable struct to store hyperparameters. This is either a subtype of Probabilistic or Deterministic, depending on whether probabilistic or ordinary point predictions are intended. This struct is the model.
- Define a fit method, dispatched on the model, returning learned parameters, also known as the fitresult.
- Define a predict method, dispatched on the model and the fitresult, to return predictions on new patterns.
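The three steps can be sketched without MLJ installed. In this stdlib-only sketch, SketchDeterministic is a hypothetical stand-in for MLJBase.Deterministic, and the local fit and predict functions only mimic the argument and return shapes of MLJBase.fit and MLJBase.predict. ShrunkMeanRegressor and its shrinkage hyperparameter are made up for illustration. In real MLJ code you would subtype MLJBase.Deterministic and extend MLJBase.fit and MLJBase.predict, as the examples below do.

```julia
# Hypothetical stand-in for MLJBase.Deterministic, so the sketch runs without MLJ.
abstract type SketchDeterministic end

# Step 1: a mutable struct holding hyperparameters *is* the model.
mutable struct ShrunkMeanRegressor <: SketchDeterministic
    shrinkage::Float64   # made-up hyperparameter: shrink the mean toward zero
end
ShrunkMeanRegressor(; shrinkage=0.0) = ShrunkMeanRegressor(shrinkage)

# Step 2: fit dispatches on the model and returns the learned parameters
# (the fitresult), together with a cache and a report (both unused here).
function fit(model::ShrunkMeanRegressor, verbosity, X, y)
    fitresult = (1 - model.shrinkage) * sum(y) / length(y)
    return fitresult, nothing, nothing
end

# Number of rows of a NamedTuple-of-vectors table (hypothetical helper).
nrows_sketch(X) = length(first(values(X)))

# Step 3: predict dispatches on the model and uses the fitresult.
predict(::ShrunkMeanRegressor, fitresult, Xnew) = fill(fitresult, nrows_sketch(Xnew))

model = ShrunkMeanRegressor(shrinkage=0.5)
fitresult, cache, report = fit(model, 0, (a = [1.0, 2.0, 3.0],), [2.0, 4.0, 6.0])
yhat = predict(model, fitresult, (a = [0.0, 1.0],))
```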
In the examples below, the training input X of fit, and the new input Xnew passed to predict, are tables. Each training target y is an AbstractVector.
The predictions returned by predict have the same form as y for deterministic models, but are Vectors of distributions for probabilistic models.
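To make these shapes concrete, here is a stdlib-only sketch with toy values. A NamedTuple of equal-length vectors is one of the simplest objects that Tables.jl treats as a (column) table, so it can serve as X or Xnew. The Dict of class => probability used for probabilistic predictions is only a hypothetical stand-in for the distribution objects MLJ actually returns.

```julia
# Input table: a NamedTuple of equal-length column vectors (toy values).
X = (x1 = [0.5, 1.5, 2.5], x2 = [1.0, 0.0, 1.0])

# Target: an AbstractVector with one entry per row of X.
y = ["a", "b", "a"]

# Deterministic predictions: same form as y, one label per row.
yhat_det = ["a", "a", "a"]

# Probabilistic predictions: a Vector holding one distribution per row.
# A Dict of class => probability stands in for MLJ's UnivariateFinite.
yhat_prob = [Dict("a" => 0.7, "b" => 0.3) for _ in 1:length(y)]

# Number of rows of the column table = length of any one column.
nrows_X = length(first(X))
```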
Advanced model functionality not addressed here includes: (i) an optional update method to avoid redundant calculations when calling fit! on machines a second time; (ii) reporting extra training-related statistics; (iii) exposing model-specific functionality; (iv) checking the scientific type of data passed to your model in machine construction; and (v) checking the validity of hyperparameter values. All this is described in Adding Models for General Use.
For an unsupervised model, implement transform and, optionally, inverse_transform, using the same signature as predict below.
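As a rough illustration of that pattern, here is a stdlib-only sketch of a transformer that centers each column. SketchUnsupervised is a hypothetical stand-in for MLJBase.Unsupervised, and the local fit, transform and inverse_transform only mimic the shapes of the MLJBase methods: fit takes no target, and transform and inverse_transform take (model, fitresult, Xnew) just like predict. Centerer is a made-up name.

```julia
using Statistics: mean

# Hypothetical stand-in for MLJBase.Unsupervised; not MLJ's own type.
abstract type SketchUnsupervised end

# A hyperparameter-free transformer that subtracts column means.
struct Centerer <: SketchUnsupervised end

# Unsupervised fit: no target y; learns one mean per column of the table.
function fit(::Centerer, verbosity, X)
    fitresult = map(mean, X)    # NamedTuple of column means
    return fitresult, nothing, nothing
end

# Same signature as predict: (model, fitresult, Xnew).
transform(::Centerer, means, Xnew) = map((col, m) -> col .- m, Xnew, means)

# Optional inverse: add the learned means back.
inverse_transform(::Centerer, means, W) = map((col, m) -> col .+ m, W, means)

X = (a = [1.0, 2.0, 3.0], b = [10.0, 20.0, 30.0])
means, _, _ = fit(Centerer(), 0, X)
W = transform(Centerer(), means, X)
```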
A simple deterministic regressor
Here's a quick-and-dirty implementation of a ridge regressor with no intercept:
import MLJBase
using LinearAlgebra
mutable struct MyRegressor <: MLJBase.Deterministic
    lambda::Float64
end
MyRegressor(; lambda=0.1) = MyRegressor(lambda)
# fit returns coefficients minimizing a penalized sum-of-squares (ridge) loss:
function MLJBase.fit(model::MyRegressor, verbosity, X, y)
    x = MLJBase.matrix(X)                       # convert table to matrix
    fitresult = (x'x + model.lambda*I)\(x'y)    # the coefficients
    cache = nothing
    report = nothing
    return fitresult, cache, report
end
# predict uses coefficients to make a new prediction:
MLJBase.predict(::MyRegressor, fitresult, Xnew) = MLJBase.matrix(Xnew) * fitresult
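The fit method above uses the closed-form ridge solution: the coefficients β = (xᵀx + λI)⁻¹xᵀy minimize ‖xβ - y‖² + λ‖β‖². The stdlib-only sketch below checks this on a tiny made-up matrix, passed directly in place of MLJBase.matrix(X): at β the gradient of the penalized loss vanishes, and nudging β increases the loss. The helper names ridge_coefficients and ridge_loss are hypothetical.

```julia
using LinearAlgebra

# Same normal-equations formula as in MyRegressor's fit, on a plain matrix.
ridge_coefficients(x, y, lambda) = (x'x + lambda*I) \ (x'y)

# Penalized objective: squared error plus an L2 penalty on the coefficients.
ridge_loss(x, y, lambda, β) = sum(abs2, x*β - y) + lambda*sum(abs2, β)

x = [1.0 0.0; 0.0 1.0; 1.0 1.0]   # toy design matrix, no intercept column
y = [1.0, 2.0, 3.0]
λ = 1.0
β = ridge_coefficients(x, y, λ)

# Gradient of ridge_loss with respect to β; zero at the minimizer.
grad = 2 * (x' * (x*β - y) + λ*β)

# Moving away from β should only increase the (strictly convex) loss.
perturbed_is_worse = ridge_loss(x, y, λ, β) < ridge_loss(x, y, λ, β .+ 0.01)
```

Here xᵀx + I = [3 1; 1 3] and xᵀy = [4, 5], so β = [7/8, 11/8].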
After loading this code, and with MLJ itself loaded (using MLJ), all of MLJ's basic meta-algorithms can be applied to MyRegressor:
julia> X, y = @load_boston;
julia> model = MyRegressor(lambda=1.0)
MyRegressor(lambda = 1.0)
julia> regressor = machine(model, X, y)
untrained Machine; caches model-specific representations of data
  model: MyRegressor(lambda = 1.0)
  args:
    1:  Source @332 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @987 ⏎ AbstractVector{Continuous}
julia> evaluate!(regressor, resampling=CV(), measure=rms, verbosity=0)
PerformanceEvaluation object with these fields:
  model, measure, operation, measurement, per_fold,
  per_observation, fitted_params_per_fold,
  report_per_fold, train_test_rows, resampling, repeats
Extract:
┌────────────────────────┬───────────┬─────────────┐
│ measure                │ operation │ measurement │
├────────────────────────┼───────────┼─────────────┤
│ RootMeanSquaredError() │ predict   │ 5.94        │
└────────────────────────┴───────────┴─────────────┘
┌──────────────────────────────────────┬─────────┐
│ per_fold                             │ 1.96*SE │
├──────────────────────────────────────┼─────────┤
│ [2.71, 4.44, 5.06, 3.47, 11.0, 5.13] │ 2.58    │
└──────────────────────────────────────┴─────────┘
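The overall measurement of 5.94 is not the plain average of the six per-fold values (about 5.30). It is, however, consistent with aggregating them as a root mean square, which is what one would expect for an rms measure. A quick arithmetic check, using only the per-fold numbers printed above:

```julia
using Statistics: mean

# Per-fold rms values copied from the evaluation output above.
per_fold = [2.71, 4.44, 5.06, 3.47, 11.0, 5.13]

plain_mean = mean(per_fold)                      # simple average of the folds
root_mean_square = sqrt(mean(abs2, per_fold))    # sqrt of the mean squared value
```

Rounded to two decimals, plain_mean is 5.30 while root_mean_square is 5.94, matching the reported measurement.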
A simple probabilistic classifier
The following probabilistic model simply fits a probability distribution to the Multiclass training target (i.e., it ignores X) and returns this distribution for any new pattern:
import MLJBase
import Distributions
struct MyClassifier <: MLJBase.Probabilistic
end
# `fit` ignores the inputs X and returns the probability distribution
# of the training target y:
function MLJBase.fit(model::MyClassifier, verbosity, X, y)
    fitresult = Distributions.fit(MLJBase.UnivariateFinite, y)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end
# `predict` returns the passed fitresult (a distribution) for all new patterns:
MLJBase.predict(model::MyClassifier, fitresult, Xnew) =
    [fitresult for r in 1:MLJBase.nrows(Xnew)]
julia> X, y = @load_iris;
julia> mach = machine(MyClassifier(), X, y) |> fit!;
[ Info: Training machine(MyClassifier(), …).
julia> predict(mach, selectrows(X, 1:2))
2-element Vector{UnivariateFinite{Multiclass{3}, String, UInt32, Float64}}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.333, versicolor=>0.333, virginica=>0.333)
 UnivariateFinite{Multiclass{3}}(setosa=>0.333, versicolor=>0.333, virginica=>0.333)
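The equal probabilities above are consistent with estimating each class probability by its relative frequency, since the iris target has 50 observations of each of its three species. The stdlib-only sketch below assumes that is what Distributions.fit(MLJBase.UnivariateFinite, y) computes. class_frequencies and predict_constant are hypothetical helpers, and the Dict is only a stand-in for UnivariateFinite; unlike UnivariateFinite, it knows nothing about classes absent from y.

```julia
# Hypothetical stand-in for Distributions.fit(MLJBase.UnivariateFinite, y):
# estimate each class probability by its relative frequency in y.
function class_frequencies(y::AbstractVector)
    counts = Dict{eltype(y),Int}()
    for label in y
        counts[label] = get(counts, label, 0) + 1
    end
    n = length(y)
    return Dict(label => c / n for (label, c) in counts)
end

# Mirrors MyClassifier's predict: one copy of the fitted distribution per new row.
predict_constant(fitresult, nrows::Integer) = [fitresult for _ in 1:nrows]

# A target with the same class counts as iris: 50 of each species.
y = vcat(fill("setosa", 50), fill("versicolor", 50), fill("virginica", 50))
fitresult = class_frequencies(y)
yhat = predict_constant(fitresult, 2)
```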