Here we give the definitive specification of the LearnAPI.jl interface. For informal guides, see Anatomy of an Implementation and Common Implementation Patterns.
Important terms and concepts
The LearnAPI.jl specification is predicated on a few basic, informally defined notions:
Data and observations
ML/statistical algorithms are frequently applied in conjunction with resampling of observations, as in cross-validation. In this document, data always refers to objects encapsulating an ordered sequence of individual observations.
A DataFrame instance, from DataFrames.jl, is an example of data, the observations being the rows. Typically, data provided to LearnAPI.jl algorithms will implement the MLCore.jl getobs/numobs interface for accessing individual observations, but implementations can opt out of this requirement; see obs and LearnAPI.data_interface for details.
In the MLCore.jl convention, observations in tables are the rows but observations in a matrix are the columns.
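The MLCore.jl convention can be pictured with a minimal sketch in plain Julia. The functions toy_numobs and toy_getobs below are hypothetical stand-ins for MLCore.jl's numobs and getobs, not those methods themselves, and the table is modelled, as an assumption, by a vector of named tuples (one per row):

```julia
# Hypothetical stand-ins for MLCore.jl's numobs/getobs (illustration only).

# Matrix: each column is one observation.
toy_numobs(X::AbstractMatrix) = size(X, 2)
toy_getobs(X::AbstractMatrix, i) = X[:, i]

# Row-based table, modelled here as a vector of named tuples: each row is one observation.
toy_numobs(rows::AbstractVector{<:NamedTuple}) = length(rows)
toy_getobs(rows::AbstractVector{<:NamedTuple}, i) = rows[i]

X = [1.0 2.0 3.0;
     4.0 5.0 6.0]                  # 2 features, 3 observations
table = [(x1 = 1.0, x2 = 4.0),
         (x1 = 2.0, x2 = 5.0),
         (x1 = 3.0, x2 = 6.0)]     # the same data, as 3 rows

toy_numobs(X), toy_numobs(table)   # (3, 3)
toy_getobs(X, 2)                   # [2.0, 5.0]
toy_getobs(table, 2)               # (x1 = 2.0, x2 = 5.0)
```

The same three observations are reached by column index in the matrix and by row index in the table.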
Besides the data it consumes, a machine learning algorithm's behavior is governed by a number of user-specified hyperparameters, such as the number of trees in a random forest. Hyperparameters are understood in a rather broad sense. In particular, hyperparameters need not be data-generic: a class weight dictionary, for example, which only makes sense for a target taking values in the set of specified dictionary keys, should be given as a hyperparameter. For simplicity and composability, LearnAPI.jl discourages "run time" parameters (extra arguments to fit), such as acceleration options (CPU/GPU/multithreading/multiprocessing). These should be included as hyperparameters as far as possible. An exception is the compulsory verbosity keyword argument of fit.
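As a sketch of this design, assume a hypothetical classifier ToyWeightedPrior that merely learns class-weighted label frequencies. Both its class weight dictionary and a multithreading switch are stored as properties of the learner, while the hypothetical training function toy_fit (not LearnAPI.jl's fit) accepts only the data and a verbosity keyword:

```julia
# Hypothetical classifier configuration (illustration only). Class weights and an
# acceleration option are stored as hyperparameters (properties of the learner),
# not passed as extra arguments at training time.
struct ToyWeightedPrior
    class_weights::Dict{Symbol,Float64}  # only meaningful for targets with these labels
    threaded::Bool                       # acceleration option, held as a hyperparameter
end
ToyWeightedPrior(; class_weights=Dict{Symbol,Float64}(), threaded=false) =
    ToyWeightedPrior(class_weights, threaded)

# Hypothetical training function. Apart from the data, its only option is the
# `verbosity` keyword. It returns class-weighted, normalized label frequencies.
function toy_fit(learner::ToyWeightedPrior, y::AbstractVector{Symbol}; verbosity=1)
    labels = sort(unique(y))
    totals = zeros(length(labels))
    weight(c) = get(learner.class_weights, c, 1.0)   # unlisted classes get weight 1
    tally!(k) = (totals[k] = weight(labels[k]) * count(==(labels[k]), y))
    if learner.threaded
        Threads.@threads for k in eachindex(labels)   # each iteration writes its own slot
            tally!(k)
        end
    else
        foreach(tally!, eachindex(labels))
    end
    verbosity > 0 && @info "Trained on $(length(y)) observations."
    return Dict(zip(labels, totals ./ sum(totals)))
end

learner = ToyWeightedPrior(class_weights=Dict(:a => 2.0))
toy_fit(learner, [:a, :b, :b]; verbosity=0)   # :a and :b each get frequency 0.5
```

Because threaded is a property, switching acceleration on means constructing a different learner; the call signature of toy_fit never changes.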
Targets and target proxies
After training, a supervised classifier predicts labels on some input, which are then compared with ground truth labels using some accuracy measure, to assess the performance of the classifier. Alternatively, the classifier predicts class probabilities, which are instead paired with ground truth labels using a proper scoring rule, say. In outlier detection, "outlier"/"inlier" predictions, or probability-like scores, are similarly compared with ground truth labels. In clustering, integer labels assigned to observations by the clustering algorithm can be paired with human labels using, say, the Rand index. In survival analysis, predicted survival functions or probability distributions are compared with censored ground truth survival times. And so on ...
More generally, whenever we have a variable that can, at least in principle, be paired with a predicted value, or some predicted "proxy" for that variable (such as a class probability), then we call the variable a target variable, and the predicted output a target proxy. In this definition, it is immaterial whether or not the target appears in training (the algorithm is supervised) or whether or not predictions generalize to new input observations (the algorithm "learns").
LearnAPI.jl provides singleton target proxy types for prediction dispatch. These are the same types used to distinguish performance metrics provided by the package StatisticalMeasures.jl.
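Singleton types let a single function select different behavior through dispatch alone. The following sketch illustrates the idea with hypothetical types PointProxy and ProbabilityProxy and a hypothetical toy_predict function; these are not LearnAPI.jl's own proxy types or its predict signature:

```julia
# Hypothetical stand-ins for target proxy types (illustration only).
abstract type ToyProxy end
struct PointProxy <: ToyProxy end        # request point predictions of the target
struct ProbabilityProxy <: ToyProxy end  # request class probabilities

# A toy "model" for a classifier that ignores input features and
# predicts from stored class frequencies.
struct ConstantClassifierModel
    classes::Vector{Symbol}
    probs::Vector{Float64}
end

# One method per proxy type: the singleton instance passed as the second
# argument selects the method. `n` stands in for the number of input observations.
toy_predict(m::ConstantClassifierModel, ::PointProxy, n::Int) =
    fill(m.classes[argmax(m.probs)], n)
toy_predict(m::ConstantClassifierModel, ::ProbabilityProxy, n::Int) =
    [Dict(zip(m.classes, m.probs)) for _ in 1:n]

model = ConstantClassifierModel([:a, :b], [0.3, 0.7])
toy_predict(model, PointProxy(), 2)        # [:b, :b]
toy_predict(model, ProbabilityProxy(), 1)  # one Dict of class probabilities
```

The singleton instances carry no data; their only role is to pick a method, which is also what makes them convenient labels for pairing predictions with compatible performance measures.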
An object implementing the LearnAPI.jl interface is called a learner, although it is more accurately "the configuration of some machine learning or statistical algorithm".¹ A learner encapsulates a particular set of user-specified hyperparameters as the object's properties (which conceivably differ from its fields). It does not store learned parameters.
Informally, we will sometimes use the word "model" to refer to the output of fit(learner, ...) (see below), something which typically does store learned parameters.
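To illustrate the distinction between a learner and a model, here is a minimal sketch using hypothetical names (ToyShrunkMean, ToyShrunkMeanModel, toy_fit, toy_predict) in plain Julia rather than the LearnAPI.jl methods. The learner holds only a hyperparameter; the object returned by training stores the learned mean, together with the learner that produced it:

```julia
# Hypothetical learner: stores only a hyperparameter, never learned parameters.
struct ToyShrunkMean
    shrinkage::Float64   # in [0, 1]; fraction by which the mean is pulled towards zero
end
ToyShrunkMean(; shrinkage=0.0) = ToyShrunkMean(shrinkage)

# Hypothetical "model": the output of training, which does store learned parameters.
struct ToyShrunkMeanModel
    learner::ToyShrunkMean
    mean::Float64
end

# Hypothetical stand-ins for fit/predict (illustration only).
toy_fit(learner::ToyShrunkMean, y::AbstractVector{<:Real}) =
    ToyShrunkMeanModel(learner, (1 - learner.shrinkage) * sum(y) / length(y))
toy_predict(model::ToyShrunkMeanModel, n::Int) = fill(model.mean, n)

learner = ToyShrunkMean(shrinkage=0.5)
model = toy_fit(learner, [2.0, 4.0, 6.0])
toy_predict(model, 2)   # [2.0, 2.0]
```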
For every learner, LearnAPI.constructor(learner) must return a keyword constructor enabling recovery of learner from its properties:
properties = propertynames(learner)
named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
@assert learner == LearnAPI.constructor(learner)(; named_properties...)
which can be tested with @assert LearnAPI.clone(learner) == learner.
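Note that if learner is an instance of a mutable struct, meeting this requirement generally means overloading Base.== for the struct, because the default == for mutable structs compares object identity rather than property values.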
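The following sketch, using a hypothetical mutable type ToyRidge, shows why. Here the keyword constructor ToyRidge stands in for LearnAPI.constructor(learner), and the round trip mirrors the property-recovery check given above:

```julia
# Hypothetical mutable learner type (illustration only).
mutable struct ToyRidge
    lambda::Float64
    max_iter::Int
end
ToyRidge(; lambda=0.1, max_iter=100) = ToyRidge(lambda, max_iter)

# Round trip through the keyword constructor; `ToyRidge` stands in for
# LearnAPI.constructor(learner).
function roundtrip(learner::ToyRidge)
    names = propertynames(learner)
    named = NamedTuple{names}(getproperty.(Ref(learner), names))
    return ToyRidge(; named...)
end

learner = ToyRidge(lambda=0.5)

# For mutable structs the default `==` falls back to `===` (object identity),
# so the reconstructed learner is not considered equal to the original:
roundtrip(learner) == learner   # false

# Overloading `==` to compare property values makes the round trip succeed.
# `Base.hash` is overloaded consistently, as is good practice alongside `==`.
Base.:(==)(a::ToyRidge, b::ToyRidge) =
    all(getproperty(a, p) == getproperty(b, p) for p in propertynames(a))
Base.hash(r::ToyRidge, h::UInt) = hash((r.lambda, r.max_iter), h)
roundtrip(learner) == learner   # true
```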
No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make deep copies of RNG hyperparameters before using them in an implementation of fit.
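A minimal sketch of this practice, assuming a hypothetical learner ToyRandomProjector with an RNG hyperparameter and a hypothetical training function toy_fit standing in for an implementation of fit, uses the Random standard library:

```julia
using Random

# Hypothetical learner with an RNG hyperparameter (illustration only).
struct ToyRandomProjector{R<:AbstractRNG}
    dim::Int   # number of output features
    rng::R
end
ToyRandomProjector(; dim=2, rng=MersenneTwister(123)) = ToyRandomProjector(dim, rng)

# Hypothetical stand-in for a `fit` implementation. Random draws are taken from a
# deep copy of the learner's RNG, so the learner itself is never mutated.
function toy_fit(learner::ToyRandomProjector, X::AbstractMatrix)
    rng = deepcopy(learner.rng)
    P = randn(rng, learner.dim, size(X, 1))   # observations are the columns of X
    return (learner = learner, projection = P)
end

learner = ToyRandomProjector(dim=2)
X = rand(5, 10)                           # 5 features, 10 observations
model1 = toy_fit(learner, X)
model2 = toy_fit(learner, X)
model1.projection == model2.projection    # true: the learner's RNG was not advanced
```

Without the deepcopy, each call would advance learner.rng, so two fits of the same learner would differ and the learner would have been mutated.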
Composite learners (wrappers)
A composite learner is one with at least one property that can take other learners as values; for such learners, LearnAPI.learners(learner) is non-empty. The keyword constructor returned by LearnAPI.constructor must provide default values for all properties that are not in LearnAPI.learners(learner). The learner-valued properties may instead have a nothing default, in which case the constructor throws an error unless the call explicitly specifies a value.
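A sketch of such a constructor follows, with hypothetical types ToyAtom (an atomic learner) and ToyBagger (a composite learner whose learner-valued property is atom):

```julia
# Hypothetical atomic and composite learner types (illustration only).
struct ToyAtom
    alpha::Float64
end
ToyAtom(; alpha=1.0) = ToyAtom(alpha)

struct ToyBagger
    atom::Any      # learner-valued property
    n_bags::Int    # ordinary hyperparameter, so it must have a default
end

# The learner-valued property `atom` defaults to `nothing`, and the constructor
# throws an error unless a value is supplied explicitly.
function ToyBagger(; atom=nothing, n_bags=10)
    atom === nothing && throw(ArgumentError("`atom` must be specified explicitly"))
    return ToyBagger(atom, n_bags)
end

bagger = ToyBagger(atom=ToyAtom(alpha=0.5), n_bags=20)
# ToyBagger(n_bags=20) would throw an ArgumentError.
```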
Below is an example of a learner type with a valid constructor:
struct GradientRidgeRegressor{T<:Real}
    learning_rate::T
    epochs::Int
    l2_regularization::T
end
"""
    GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01)
Instantiate a gradient ridge regressor with the specified hyperparameters.
"""
GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01) =
    GradientRidgeRegressor(learning_rate, epochs, l2_regularization)
LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
Testing something is a learner
Any object for which LearnAPI.functions(object) is non-empty is understood to have a valid implementation of the LearnAPI.jl interface. You can test this with the convenience method LearnAPI.is_learner(object), but this method is never explicitly overloaded.
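The relationship between the two methods can be sketched in plain Julia. The functions toy_functions and toy_is_learner are hypothetical stand-ins for LearnAPI.functions and LearnAPI.is_learner, and returning symbols from toy_functions is an assumption made for illustration; only the former is extended for a new type, and the latter follows automatically:

```julia
# Hypothetical stand-ins (illustration only).

toy_functions(object) = ()                        # fallback: nothing available, not a learner

struct ToyLearner end
toy_functions(::ToyLearner) = (:fit, :predict)    # a new implementation opts in here

# Derived entirely from toy_functions, so it never needs overloading for new types.
toy_is_learner(object) = !isempty(toy_functions(object))

toy_is_learner(ToyLearner())   # true
toy_is_learner(42)             # false
```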
Attach public LearnAPI.jl-related documentation for a learner to its constructor, rather than to the struct defining its type, as shown in the example above. (In this way, multiple interfaces can share a common struct, with separate document strings for each interface.)
All new learner types must implement fit, LearnAPI.learner, LearnAPI.constructor and LearnAPI.functions.
Most learners will also implement predict and/or transform.
List of methods
- fit: for (i) training learners that generalize to new data; or (ii) wrapping learner in an object that is possibly mutated by predict, to record byproducts of those operations, in the special case of non-generalizing learners (called here static algorithms).
- update: for updating learning outcomes after hyperparameter changes, such as increasing an iteration parameter.
- update_observations: for updating learning outcomes by presenting additional training data.
- predict: for outputting targets or target proxies (such as probability density functions).
- transform: similar to predict, but for arbitrary kinds of output, and which can be paired with an inverse_transform.
- inverse_transform: for inverting the output of transform ("inverting" broadly understood).
- obs: method for exposing to the user learner-specific representations of data, which are additionally guaranteed to implement the observation access API specified by LearnAPI.data_interface(learner).
- Training data deconstructors: for extracting relevant parts of training data, where defined.
- Accessor functions: functions for extracting, from training outcomes, information common to many learners. These include LearnAPI.strip(model), for replacing a learning outcome, model, with a serializable version that can still predict.
- Learner traits: methods that promise specific learner behavior or record general information about the learner. Only LearnAPI.constructor and LearnAPI.functions are universally compulsory.
Utilities:
- LearnAPI.clone: for cloning a learner with specified hyperparameter replacements.
- @trait: for simultaneously declaring multiple traits.
- @functions: for listing functions available for use with a learner.
LearnAPI.is_learner(object)
Returns true if object has a valid implementation of the LearnAPI.jl interface. Equivalent to non-emptiness of LearnAPI.functions(object).
This trait should never be overloaded explicitly.
LearnAPI.clone(learner, replacements...)
LearnAPI.clone(learner; replacements...)
Return a shallow copy of learner with the specified hyperparameter replacements. Two syntaxes are supported, as shown in the following examples:
clone(learner, :epochs => 100, :learning_rate => 0.01)
clone(learner; epochs=100, learning_rate=0.01)
A LearnAPI.jl contract ensures that LearnAPI.clone(learner) == learner.
A new learner implementation does not overload clone.
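One way such a clone could be built on top of a keyword constructor is sketched below. This is a hypothetical illustration, not LearnAPI.jl's implementation: toy_clone and ToyRegressor are invented for the example, and the constructor is passed explicitly rather than obtained from LearnAPI.constructor:

```julia
# Hypothetical learner type (illustration only).
struct ToyRegressor
    learning_rate::Float64
    epochs::Int
end
ToyRegressor(; learning_rate=0.01, epochs=10) = ToyRegressor(learning_rate, epochs)

# Keyword form: read the current properties, merge in the replacements, and
# rebuild through the keyword constructor (a shallow copy).
function toy_clone(learner, constructor; replacements...)
    names = propertynames(learner)
    current = NamedTuple{names}(getproperty.(Ref(learner), names))
    return constructor(; merge(current, values(replacements))...)
end

# Pair form, requiring at least one replacement: forwards to the keyword form.
function toy_clone(learner, constructor, first::Pair{Symbol}, rest::Pair{Symbol}...)
    kw = (first, rest...)
    return toy_clone(learner, constructor; kw...)
end

learner = ToyRegressor()
toy_clone(learner, ToyRegressor, :epochs => 100)   # ToyRegressor(0.01, 100)
toy_clone(learner, ToyRegressor; epochs=100)       # ToyRegressor(0.01, 100)
toy_clone(learner, ToyRegressor) == learner        # true
```

Routing everything through the keyword constructor is what ties clone to the constructor requirement stated earlier: with no replacements, the clone is exactly the round-tripped learner.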
@trait(LearnerType, trait1=value1, trait2=value2, ...)
Simultaneously overload a number of traits for learners of type LearnerType. For example, the code
@trait(
    RidgeRegressor,
    tags = ("regression", ),
    doc_url = "",
)
is equivalent to
LearnAPI.tags(::RidgeRegressor) = ("regression", )
LearnAPI.doc_url(::RidgeRegressor) = ""
@functions learner
Return a tuple of functions that can be meaningfully applied with learner, or an associated model, as the first argument. An "associated model" is an object returned by fit(learner, ...). Learner traits (methods for which learner is always the only argument) are excluded.
julia> @functions my_feature_selector
(fit, LearnAPI.learner, clone, strip, obs, transform)
New learner implementations should overload LearnAPI.functions.
See also LearnAPI.functions.
¹ We acknowledge users may not like this terminology, and may know "learner" by some other name, such as "strategy", "options", "hyperparameter set", "configuration", "algorithm", or "model". Consensus on this point is difficult; see, e.g., this Julia Discourse discussion.