Reference

Here we give the definitive specification of the LearnAPI.jl interface. For informal guides see Anatomy of an Implementation and Common Implementation Patterns.

Important terms and concepts

The LearnAPI.jl specification is predicated on a few basic, informally defined notions:

Data and observations

ML/statistical algorithms are typically applied in conjunction with resampling of observations, as in cross-validation. In this document data will always refer to objects encapsulating an ordered sequence of individual observations.

A DataFrame instance, from DataFrames.jl, is an example of data, the observations being the rows. Typically, data provided to LearnAPI.jl algorithms will implement the MLUtils.jl getobs/numobs interface for accessing individual observations, but implementations can opt out of this requirement; see obs and LearnAPI.data_interface for details.

Note

In the MLUtils.jl convention, observations in tables are the rows but observations in a matrix are the columns.
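The convention can be illustrated with plain indexing (a sketch only; MLUtils.getobs implements these access rules generically):

```julia
# Sketch of the MLUtils.jl observation convention, using plain indexing.

# A matrix: observations are the *columns*.
X = [1 2 3;
     4 5 6]           # 2 features, 3 observations
obs2 = X[:, 2]        # second observation: [2, 5]

# A table (here, a column-table as a NamedTuple): observations are the *rows*.
t = (x1 = [1, 2, 3], x2 = [4, 5, 6])
row2 = map(col -> col[2], t)   # second observation: (x1 = 2, x2 = 5)
```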

Hyperparameters

Besides the data it consumes, a machine learning algorithm's behavior is governed by a number of user-specified hyperparameters, such as the number of trees in a random forest. In LearnAPI.jl, one is allowed to have hyperparameters that are not data-generic. For example, a class weight dictionary, which will only make sense for a target taking values in the set of dictionary keys, can be specified as a hyperparameter.
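To illustrate, here is a hypothetical learner whose class-weight hyperparameter is tied to particular target values (WeightedClassifier is an invented name, not part of LearnAPI.jl):

```julia
# A hypothetical learner with a non-data-generic hyperparameter: the
# `class_weights` dictionary only makes sense for targets taking values
# in `keys(class_weights)`.
struct WeightedClassifier
    class_weights::Dict{String,Float64}
end
WeightedClassifier(; class_weights=Dict{String,Float64}()) =
    WeightedClassifier(class_weights)

learner = WeightedClassifier(class_weights = Dict("cat" => 2.0, "dog" => 1.0))
```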

Targets and target proxies

Context

After training, a supervised classifier predicts labels on some input, which are then compared with ground truth labels using some accuracy measure, to assess the performance of the classifier. Alternatively, the classifier predicts class probabilities, which are instead paired with ground truth labels using a proper scoring rule, say. In outlier detection, "outlier"/"inlier" predictions, or probability-like scores, are similarly compared with ground truth labels. In clustering, integer labels assigned to observations by the clustering algorithm can be paired with human labels using, say, the Rand index. In survival analysis, predicted survival functions or probability distributions are compared with censored ground truth survival times. And so on ...

Definitions

More generally, whenever we have a variable (e.g., a class label) that can, at least in principle, be paired with a predicted value, or some predicted "proxy" for that variable (such as a class probability), then we call the variable a target variable, and the predicted output a target proxy. In this definition, it is immaterial whether or not the target appears in training (the algorithm is supervised) or whether or not predictions generalize to new input observations (the algorithm "learns").

LearnAPI.jl provides singleton target proxy types for prediction dispatch. These are also used to distinguish performance metrics provided by the package StatisticalMeasures.jl.
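To illustrate the dispatch idea, here is a sketch using stand-in singleton types. The names Point and Distribution, and the three-argument predict signature, mimic LearnAPI.jl conventions but are defined locally here; consult the package documentation for the actual proxy types:

```julia
# Stand-in singleton target proxy types (invented for this sketch):
struct Point end          # a "literal" target prediction
struct Distribution end   # a probabilistic target prediction

struct DummyClassifier end   # placeholder for a fitted model

# Dispatch on the proxy type selects the form of the prediction:
predict(::DummyClassifier, ::Point, score) = score >= 0.5 ? "pos" : "neg"
predict(::DummyClassifier, ::Distribution, score) = (neg = 1 - score, pos = score)

model = DummyClassifier()
hard = predict(model, Point(), 0.75)         # "pos"
soft = predict(model, Distribution(), 0.75)  # (neg = 0.25, pos = 0.75)
```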

Learners

An object implementing the LearnAPI.jl interface is called a learner, although it is more accurately "the configuration of some machine learning or statistical algorithm".¹ A learner encapsulates a particular set of user-specified hyperparameters as the object's properties (which conceivably differ from its fields). It does not store learned parameters.

Informally, we will sometimes use the word "model" to refer to the output of fit(learner, ...) (see below), something which typically does store learned parameters.

For learner to be a valid LearnAPI.jl learner, LearnAPI.constructor(learner) must be defined and return a keyword constructor enabling recovery of learner from its properties:

properties = propertynames(learner)
named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
@assert learner == LearnAPI.constructor(learner)(; named_properties...)

which can be tested with @assert LearnAPI.clone(learner) == learner.

Note that if learner is an instance of a mutable struct, meeting this requirement generally means overloading Base.== for the struct.
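For example (a sketch, with an invented learner type):

```julia
# A mutable learner type generally needs a custom `==`, because mutable
# structs compare by identity (`===`) by default, so an instance
# reconstructed from its properties would otherwise not compare equal.
mutable struct IterativeLearner
    epochs::Int
end
IterativeLearner(; epochs=10) = IterativeLearner(epochs)

Base.:(==)(a::IterativeLearner, b::IterativeLearner) = a.epochs == b.epochs
```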

Important

No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make deep copies of RNG hyperparameters before using them in a new implementation of fit.

Composite learners (wrappers)

A composite learner is one with at least one property that can take other learners as values; for such learners LearnAPI.is_composite(learner) must be true (the fallback is false). Generally, the keyword constructor provided by LearnAPI.constructor must provide default values for all properties that are not learner-valued. Learner-valued properties can instead have a nothing default, with the constructor throwing an error if a value for such a property is not explicitly specified in the constructor call.
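A hypothetical sketch (BaggedEnsemble is an invented type, not part of LearnAPI.jl):

```julia
# Hypothetical composite learner: `atom` is learner-valued, so it gets a
# `nothing` default, and the keyword constructor throws unless a value
# is explicitly specified.
struct BaggedEnsemble{A}
    atom::A     # the wrapped learner
    n::Int      # number of bagged copies
end
function BaggedEnsemble(; atom=nothing, n=100)
    isnothing(atom) && error("Please explicitly specify `atom`.")
    return BaggedEnsemble(atom, n)
end
```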

Any object learner for which LearnAPI.functions(learner) is non-empty is understood to have a valid implementation of the LearnAPI.jl interface.

Example

Below is an example of a learner type with a valid constructor:

struct GradientRidgeRegressor{T<:Real}
    learning_rate::T
    epochs::Int
    l2_regularization::T
end
GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01) =
    GradientRidgeRegressor(learning_rate, epochs, l2_regularization)
LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
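One can check the constructor contract on an instance directly (definitions from above repeated so the snippet stands alone; in practice @assert LearnAPI.clone(learner) == learner performs the same check):

```julia
# Definitions from the example above:
struct GradientRidgeRegressor{T<:Real}
    learning_rate::T
    epochs::Int
    l2_regularization::T
end
GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01) =
    GradientRidgeRegressor(learning_rate, epochs, l2_regularization)

# Recover an instance from its properties via the keyword constructor:
learner = GradientRidgeRegressor(epochs=50)
properties = propertynames(learner)
named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
@assert learner == GradientRidgeRegressor(; named_properties...)
```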

Documentation

Attach public LearnAPI.jl-related documentation for a learner to its constructor, rather than to the struct defining its type. In this way, a learner can implement multiple interfaces in addition to the LearnAPI interface, with separate document strings for each.

Methods

Compulsory methods

All new learner types must implement fit, LearnAPI.learner, LearnAPI.constructor and LearnAPI.functions.

Most learners will also implement predict and/or transform. For a minimal (but useless) implementation, see the implementation of SmallLearner here.

List of methods

  • fit: for (i) training learners that generalize to new data; or (ii) wrapping learner in an object that is possibly mutated by predict/transform, to record byproducts of those operations, in the special case of non-generalizing learners (called here static algorithms)

  • update: for updating learning outcomes after hyperparameter changes, such as increasing an iteration parameter.

  • update_observations, update_features: update learning outcomes by presenting additional training data.

  • predict: for outputting targets or target proxies (such as probability density functions)

  • transform: similar to predict, but for arbitrary kinds of output, and which can be paired with an inverse_transform method

  • inverse_transform: for inverting the output of transform ("inverting" broadly understood)

  • obs: method for exposing to the user learner-specific representations of data, which are additionally guaranteed to implement the observation access API specified by LearnAPI.data_interface(learner).

  • LearnAPI.target, LearnAPI.weights, LearnAPI.features: for extracting relevant parts of training data, where defined.

  • Accessor functions: these include functions like LearnAPI.feature_importances and LearnAPI.training_losses, for extracting, from training outcomes, information common to many learners. This includes LearnAPI.strip(model) for replacing a learning outcome model with a serializable version that can still predict or transform.

  • Learner traits: methods that promise specific learner behavior or record general information about the learner. Only LearnAPI.constructor and LearnAPI.functions are universally compulsory.

Utilities

LearnAPI.clone — Function
LearnAPI.clone(learner; replacements...)

Return a shallow copy of learner with the specified hyperparameter replacements.

clone(learner; epochs=100, learning_rate=0.01)

A LearnAPI.jl contract ensures that LearnAPI.clone(learner) == learner.

LearnAPI.@trait — Macro
@trait(LearnerType, trait1=value1, trait2=value2, ...)

Overload a number of traits for learners of type LearnerType. For example, the code

@trait(
    RidgeRegressor,
    tags = ("regression", ),
    doc_url = "https://some.cool.documentation",
)

is equivalent to

LearnAPI.tags(::RidgeRegressor) = ("regression", )
LearnAPI.doc_url(::RidgeRegressor) = "https://some.cool.documentation"

¹ We acknowledge users may not like this terminology, and may know "learner" by some other name, such as "strategy", "options", "hyperparameter set", "configuration", "algorithm", or "model". Consensus on this point is difficult; see, e.g., this Julia Discourse discussion.