Quick start guide
The following are condensed and informal instructions for implementing the MLJ model interface for a new machine learning model. We assume: (i) you have a registered Julia package YourPackage.jl implementing some machine learning models; (ii) that you would like to interface and register these models with MLJ; and (iii) that you have a rough understanding of how things work with MLJ. In particular, you are familiar with:
- what scientific types are
- what `Probabilistic`, `Deterministic` and `Unsupervised` models are
- the fact that MLJ generally works with tables rather than matrices. Here a table is a container `X` satisfying the Tables.jl API and satisfying `Tables.istable(X) == true` (e.g., DataFrame, JuliaDB table, CSV file, named tuple of equal-length vectors)
- CategoricalArrays.jl, if working with finite discrete data, e.g., doing classification; see also the Working with Categorical Data section of the MLJ manual.
If you're not familiar with any one of these points, the Getting Started section of the MLJ manual may help.
But tables don't make sense for my model!
If a case can be made that tabular input does not make sense for your particular model, then MLJ can still handle this; you just need to define a non-tabular `input_scitype` trait. However, you should probably open an issue to clarify the appropriate declaration. The discussion below assumes input data is tabular.
For simplicity, this document assumes no data front-end is to be defined for your model. A data front-end, which offers the MLJ user some performance benefits, is easy to add post facto, and is described in Implementing a data front-end.
Overview
To write an interface, create a file or a module in your package which includes:
- a `using MLJModelInterface` or `import MLJModelInterface` statement
- MLJ-compatible model types and constructors,
- implementation of `fit`, `predict`/`transform`, and optionally `fitted_params` for your models,
- metadata for your package and for each of your models
Important
MLJModelInterface is a very light-weight interface allowing you to define your interface, but does not provide the functionality required to use or test your interface; this requires MLJBase. So, while you only need to add `MLJModelInterface` to your project's `[deps]`, for testing purposes you need to add `MLJBase` to your project's `[extras]` and `[targets]`. In testing, simply use `MLJBase` in place of `MLJModelInterface`.
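For orientation, a test file might look something like the following sketch. The names `YourPackage` and `YourModel` are hypothetical, and the sketch assumes, purely for illustration, that the model is a deterministic regressor accepting continuous tabular input; only MLJBase machinery (`make_regression`, `machine`, `fit!`, `predict`) is used:

```julia
# test/runtests.jl (sketch) -- MLJBase supplies the machinery that
# MLJModelInterface itself deliberately omits
using Test
using MLJBase
using YourPackage   # hypothetical package exporting `YourModel`

@testset "YourModel fit/predict" begin
    X, y = make_regression(100, 3)      # synthetic table `X` and continuous vector `y`
    mach = machine(YourModel(), X, y)   # bind the model to the data
    fit!(mach, verbosity=0)
    yhat = predict(mach, X)
    @test length(yhat) == length(y)
end
```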
Below we give some details for each step, in each case with a few examples that you can mimic. The instructions are intentionally brief.
Model type and constructor
MLJ-compatible constructors for your models need to meet the following requirements:
- be `mutable struct`s,
- be subtypes of `MLJModelInterface.Probabilistic` or `MLJModelInterface.Deterministic` or `MLJModelInterface.Unsupervised`,
- have fields corresponding exclusively to hyperparameters,
- have a keyword constructor assigning default values to all hyperparameters.
You may use the `@mlj_model` macro from `MLJModelInterface` to declare a (non-parametric) model type:
```julia
MLJModelInterface.@mlj_model mutable struct YourModel <: MLJModelInterface.Deterministic
    a::Float64 = 0.5::(_ > 0)
    b::String  = "svd"::(_ in ("svd", "qr"))
end
```
That macro specifies:
- A keyword constructor (here `YourModel(; a=..., b=...)`),
- Default values for the hyperparameters,
- Constraints on the hyperparameters, where `_` refers to a value passed.
Further to the last point, `a::Float64 = 0.5::(_ > 0)` indicates that the field `a` is a `Float64`, takes `0.5` as its default value, and expects its value to be positive.
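For example, the keyword constructor generated by the macro can then be used like this (a usage sketch based on the `YourModel` definition above):

```julia
model = YourModel()                     # uses the defaults a = 0.5, b = "svd"
model = YourModel(a = 1.2, b = "qr")    # override the defaults by keyword
```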
Please see this issue for a known problem, and its workaround, relating to the use of `@mlj_model` with negative defaults.
If you decide not to use the `@mlj_model` macro (e.g., in the case of a parametric type), you will need to write a keyword constructor and a `clean!` method:
```julia
mutable struct YourModel <: MLJModelInterface.Deterministic
    a::Float64
end

function YourModel(; a=0.5)
    model = YourModel(a)
    message = MLJModelInterface.clean!(model)
    isempty(message) || @warn message
    return model
end

function MLJModelInterface.clean!(m::YourModel)
    warning = ""
    if m.a <= 0
        warning *= "Parameter `a` expected to be positive, resetting to 0.5"
        m.a = 0.5
    end
    return warning
end
```
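Since the parametric case is mentioned above as one reason to skip the macro, here is one possible sketch of the same pattern for a hypothetical parametric model type; the name `YourParametricModel` and its single field are invented for illustration:

```julia
mutable struct YourParametricModel{T<:AbstractFloat} <: MLJModelInterface.Deterministic
    a::T
end

# keyword constructor; the type parameter is inferred from the value supplied
function YourParametricModel(; a::AbstractFloat = 0.5)
    model = YourParametricModel{typeof(a)}(a)
    message = MLJModelInterface.clean!(model)
    isempty(message) || @warn message
    return model
end

function MLJModelInterface.clean!(m::YourParametricModel)
    warning = ""
    if m.a <= 0
        warning *= "Parameter `a` expected to be positive, resetting to 0.5"
        m.a = 0.5   # converted to the field's concrete type on assignment
    end
    return warning
end
```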
Additional notes:
- Please annotate all fields with concrete types, if possible, using type parameters if necessary.
- Please prefer `Symbol` over `String` if you can (e.g., to pass the name of a solver).
- Please add constraints to your fields even if they seem obvious to you.
- Your model may have zero fields; that's fine.
- Although not essential, try to avoid `Union` types for model fields. For example, a field declaration `features::Vector{Symbol}` with a default of `Symbol[]` (detected with the `isempty` method) is preferred to `features::Union{Vector{Symbol}, Nothing}` with a default of `nothing`.
Examples:
- KNNClassifier, which uses `@mlj_model`,
- XGBoostRegressor, which does not.
Fit
The implementation of `fit` will look like:
```julia
function MLJModelInterface.fit(m::YourModel, verbosity, X, y, w=nothing)
    # body ...
    return (fitresult, cache, report)
end
```
where `y` should only be there for a supervised model, and `w` only for a supervised model that supports sample weights. You must annotate `verbosity` with the type `Int`, and you must not add type annotations to `X`, `y`, or `w` (MLJ handles that).
Regressor
In the body of the `fit` function, you should assume that `X` is a table and that `y` is an `AbstractVector` (for multitask regression it may be a table).
Typical steps in the body of the `fit` function will be:
- forming a matrix view of the data, possibly transposed if your model expects a `p x n` layout (MLJ assumes columns are features by default, i.e., `n x p`); use `MLJModelInterface.matrix` for this,
- passing the data to your model,
- returning the results as a tuple `(fitresult, cache, report)`.
The `fitresult` part should contain everything that is needed at the `predict` or `transform` step; it is not expected to be accessed by users. The `cache` should be left as `nothing` for now. The `report` should be a `NamedTuple` with any auxiliary useful information that a user would want to know about the fit (e.g., feature rankings). See more on this below.
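To make these steps concrete, here is a minimal sketch of `fit` for a hypothetical deterministic regressor that just solves a least-squares problem; the type `SketchRegressor`, its field, and the report contents are all invented for illustration:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

MMI.@mlj_model mutable struct SketchRegressor <: MMI.Deterministic
    fit_intercept::Bool = true
end

function MMI.fit(m::SketchRegressor, verbosity, X, y)
    Xmat = MMI.matrix(X)                     # n x p matrix view of the table
    A = m.fit_intercept ?
        hcat(Xmat, ones(eltype(Xmat), size(Xmat, 1))) : Xmat
    coefs = A \ y                            # least-squares solve
    fitresult = coefs                        # everything `predict` will need
    cache = nothing                          # leave as `nothing` for now
    report = (dof = size(A, 2),)             # auxiliary info for the user
    return fitresult, cache, report
end
```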
Example: GLM's LinearRegressor
Classifier
For a classifier, the steps are fairly similar to a regressor with these differences:
- `y` will be a categorical vector and you will typically want to use the integer encoding of `y` instead of `CategoricalValue`s; use `MLJModelInterface.int` for this.
- You will need to pass the full pool of target labels (not just those observed in the training data) and additionally, in the `Deterministic` case, the encoding, to make these available to `predict`. A simple way to do this is to pass `y[1]` in the `fitresult`, for then `MLJModelInterface.classes(y[1])` is a complete list of possible categorical elements, and `d = MLJModelInterface.decoder(y[1])` is a method for recovering categorical elements from their integer representations (e.g., `d(2)` is the categorical element with `2` as encoding).
- In the case of a probabilistic classifier, you should pass all probabilities simultaneously to the `UnivariateFinite` constructor to get an abstract `UnivariateFinite` vector (type `UnivariateFiniteArray`) rather than use a comprehension or broadcasting to get a vanilla vector. This is for performance reasons (see the sketch following this list).
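The following sketch, for a hypothetical probabilistic classifier `SketchClassifier`, illustrates the pattern described in the last two points; `train_classifier` and `class_probabilities` are placeholders standing in for whatever your package actually provides:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

MMI.@mlj_model mutable struct SketchClassifier <: MMI.Probabilistic
    n_iter::Int = 10::(_ > 0)
end

function MMI.fit(m::SketchClassifier, verbosity, X, y)
    Xmat = MMI.matrix(X)
    yint = MMI.int(y)                             # integer codes of the target
    theta = train_classifier(Xmat, yint, m.n_iter)  # hypothetical training routine
    fitresult = (theta, y[1])                     # keep one observation to recover the pool
    return fitresult, nothing, NamedTuple()
end

function MMI.predict(m::SketchClassifier, fitresult, Xnew)
    theta, an_element = fitresult
    levels = MMI.classes(an_element)              # the complete, ordered class pool
    probs = class_probabilities(theta, MMI.matrix(Xnew))  # hypothetical n x c matrix
    return MMI.UnivariateFinite(levels, probs)    # a UnivariateFiniteArray
end
```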
If implementing a classifier, you should probably consult the more detailed instructions at The predict method.
Examples:
- GLM's BinaryClassifier (`Probabilistic`)
- LIBSVM's SVC (`Deterministic`)
Transformer
Nothing special for a transformer.
Example: FillImputer
Fitted parameters
There is a function you can optionally implement which returns the learned parameters of your model for user inspection. For instance, in the case of a linear regression, the user may want direct access to the coefficients and intercept. The result should be as human- and machine-readable as practical (not a graphical representation), and the information should be combined in the form of a named tuple.
The function will always look like:
```julia
function MLJModelInterface.fitted_params(model::YourModel, fitresult)
    # extract what's relevant from `fitresult`
    # ...
    # then return as a NamedTuple
    return (learned_param1 = ..., learned_param2 = ...)
end
```
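As a concrete illustration, for the hypothetical `SketchRegressor` sketched in the Fit section (whose `fitresult` is just the coefficient vector, with the intercept last when `fit_intercept = true`), this might be:

```julia
function MMI.fitted_params(model::SketchRegressor, fitresult)
    if model.fit_intercept
        return (coefficients = fitresult[1:end-1], intercept = fitresult[end])
    else
        return (coefficients = fitresult, intercept = nothing)
    end
end
```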
Example: for GLM models
Summary of user interface points (or, What to put where?)
Recall that the `fitresult` returned as part of `fit` represents everything needed by `predict` (or `transform`) to make new predictions. It is not intended to be directly inspected by the user. Here is a summary of the interface points for users that your implementation creates:
- Use `fitted_params` to expose learned parameters, such as linear coefficients, to the user in a machine- and human-readable form (for re-use in another model, for example).
- Use the fields of your model struct for hyperparameters, i.e., those parameters declared by the user ahead of time that generally affect the outcome of training. It is okay to add "control" parameters (such as an `acceleration` parameter specifying computational resources, as here).
- Use `report` to return everything else, including model-specific methods (or other callable objects). This includes feature rankings, decision boundaries, SVM support vectors, clustering centres, methods for visualizing training outcomes, methods for saving learned parameters in a custom format, degrees of freedom, deviance, etc. If there is a performance cost to extra functionality you want to expose, the functionality can be toggled on/off through a hyperparameter, but this should otherwise be avoided. For example, in a decision tree model, `report.print_tree(depth)` might generate a pretty representation of the learned tree, up to the specified `depth`.
Predict/Transform
The implementation of `predict` (for a supervised model) or `transform` (for an unsupervised one) will look like:
```julia
function MLJModelInterface.predict(m::YourModel, fitresult, Xnew)
    # ...
end
```
Here `Xnew` is expected to be a table, and part of the logic in `predict` or `transform` may be similar to that in `fit`.
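Continuing the hypothetical `SketchRegressor` from the Fit section (and re-using its `const MMI = MLJModelInterface` alias), the matching `predict` might look like this sketch:

```julia
function MMI.predict(m::SketchRegressor, fitresult, Xnew)
    Xmat = MMI.matrix(Xnew)                  # table -> matrix, as in `fit`
    A = m.fit_intercept ?
        hcat(Xmat, ones(eltype(Xmat), size(Xmat, 1))) : Xmat
    return A * fitresult                     # vector of point predictions
end
```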
The values returned should be:
| model subtype | return value of `predict`/`transform` |
| --- | --- |
| `Deterministic` | vector of values (or table if multi-target) |
| `Probabilistic` | vector of `Distribution` objects; for classifiers in particular, a vector of `UnivariateFinite` |
| `Unsupervised` | table |
In the case of a `Probabilistic` model, you may further want to implement a `predict_mean` or a `predict_mode`. However, MLJModelInterface provides fallbacks, defined in terms of `predict`, whose performance may suffice.
Examples
- Deterministic regression: KNNRegressor
- Probabilistic regression: LinearRegressor and its `predict_mean`
- Probabilistic classification: LogisticClassifier
Metadata (traits)
Adding metadata for your model(s) is crucial for the discoverability of your package and its models, and to make sure your model is used only with data it can handle. You can individually overload a number of trait functions that encode this metadata by following the instructions in Adding Models for General Use, which also explains these traits in more detail. However, your most convenient option is to use the `metadata_model` and `metadata_pkg` functionalities from `MLJModelInterface` to do this:
```julia
const ALL_MODELS = Union{YourModel1, YourModel2, ...}

MLJModelInterface.metadata_pkg.(ALL_MODELS,
    name       = "YourPackage",
    uuid       = "6ee0df7b-...", # see your Project.toml
    url        = "https://...",  # URL to your package repo
    julia      = true,           # is it written entirely in Julia?
    license    = "MIT",          # your package license
    is_wrapper = false,          # does it wrap around some other package?
)

# Then for each model,
MLJModelInterface.metadata_model(YourModel1,
    input_scitype    = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported?
    target_scitype   = AbstractVector{MLJModelInterface.Continuous},          # for a supervised model, what target?
    output_scitype   = MLJModelInterface.Table(MLJModelInterface.Continuous), # for an unsupervised model, what output?
    supports_weights = false,                                                 # does the model support sample weights?
    descr            = "A short description of your model",
    load_path        = "YourPackage.SubModuleContainingModelStructDefinition.YourModel1",
)
```
Important. Do not omit the `load_path` specification. Without a correct `load_path`, MLJ will be unable to import your model.
Examples:
- package metadata
- model metadata