Anatomy of an Implementation

The core LearnAPI.jl pattern looks like this:

model = fit(learner, data)
predict(model, newdata)

Here learner specifies hyperparameters, while model stores learned parameters and any byproducts of algorithm execution.

There are variations on this pattern, but the one above is the most basic.

Elaborating on the core pattern above, this tutorial details an implementation of LearnAPI.jl for naive ridge regression with no intercept. The kind of workflow we want to enable has been previewed in Sample workflow. Readers can also refer to the demonstration of the implementation given later.

A basic implementation

See here for code without explanations.

We suppose our algorithm's fit method consumes data in the form (X, y), where X is a suitable table¹ (the features) and y a vector (the target).

Important

Implementations wishing to support other data patterns may need to take additional steps explained under Other data patterns below.

The first line below imports the lightweight package LearnAPI.jl whose methods we will be extending. The second imports libraries needed for the core algorithm.

using LearnAPI
using LinearAlgebra, Tables

Defining learners

Here's a new type whose instances specify the single ridge regression hyperparameter:

struct Ridge{T<:Real}
    lambda::T
end

Instances of Ridge are learners, in LearnAPI.jl parlance.

Associated with each new type of LearnAPI.jl learner will be a keyword argument constructor, providing default values for all properties (typically, struct fields) that are not other learners, and we must implement LearnAPI.constructor(learner), for recovering the constructor from an instance:

"""
    Ridge(; lambda=0.1)

Instantiate a ridge regression learner, with regularization of `lambda`.
"""
Ridge(; lambda=0.1) = Ridge(lambda)
LearnAPI.constructor(::Ridge) = Ridge

For example, in this case, if learner = Ridge(0.2), then LearnAPI.constructor(learner)(lambda=0.2) == learner is true. Note that we attach the docstring to the constructor, not the struct.
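
For example, here is a quick check one might run, assuming the definitions above:

learner = Ridge(0.2)
@assert LearnAPI.constructor(learner)(lambda=0.2) == learner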

Implementing fit

A ridge regressor requires two types of data for training: input features X, which here we suppose are tabular¹, and a target y, which we suppose is a vector.⁴

It is convenient to define a new type for the fit output, which will include coefficients labelled by feature name for inspection after training:

struct RidgeFitted{T,F}
    learner::Ridge
    coefficients::Vector{T}
    named_coefficients::F
end

Note that we also include learner in the struct, for it must be possible to recover learner from the output of fit; see Accessor functions below.

The implementation of fit looks like this:

function LearnAPI.fit(learner::Ridge, data; verbosity=1)
    X, y = data

    # data preprocessing:
    table = Tables.columntable(X)
    names = Tables.columnnames(table) |> collect
    A = Tables.matrix(table, transpose=true)

    lambda = learner.lambda

    # apply core algorithm:
    coefficients = (A*A' + learner.lambda*I)\(A*y) # vector

    # determine named coefficients:
    named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]

    # make some noise, if allowed:
    verbosity > 0 && @info "Coefficients: $named_coefficients"

    return RidgeFitted(learner, coefficients, named_coefficients)
end

Implementing predict

One way users will be able to call predict is like this:

predict(model, Point(), Xnew)

where Xnew is a table (of the same form as X above). The argument Point() signals that literal predictions of the target variable are sought, as opposed to some proxy for the target, such as probability density functions. Point is an example of a LearnAPI.KindOfProxy type. Targets and target proxies are discussed here.

We provide this implementation for our ridge regressor:

LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
    Tables.matrix(Xnew)*model.coefficients

If the kind of proxy is omitted, as in predict(model, Xnew), then a fallback grabs the first element of the tuple returned by LearnAPI.kinds_of_proxy(learner), which we overload appropriately below.
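
For example, once the kinds_of_proxy trait (defined under Learner traits below) is in place, the two calls below should be equivalent (a sketch only, where model is the output of fit and Xnew is a feature table as above):

predict(model, Xnew)           # fallback supplies the first kind of proxy, Point()
predict(model, Point(), Xnew)  # explicit form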

Accessor functions

An accessor function has the output of fit as its sole argument. Every new implementation must implement the accessor function LearnAPI.learner for recovering a learner from a fitted object:

LearnAPI.learner(model::RidgeFitted) = model.learner

Other accessor functions extract learned parameters or some standard byproducts of training, such as feature importances or training losses.² Here we implement an accessor function to extract the linear coefficients:

LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients

The LearnAPI.strip(model) accessor function is for returning a version of model suitable for serialization (typically smaller, and with data anonymized). It has a fallback that just returns model, but for the sake of illustration we overload it to dump the named version of the coefficients:

LearnAPI.strip(model::RidgeFitted) =
    RidgeFitted(model.learner, model.coefficients, nothing)

Crucially, we can still use LearnAPI.strip(model) in place of model to make new predictions.
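
For instance, under the assumptions of this tutorial, predictions from the stripped model should agree with those of the original (a sketch, with model and Xnew as above):

small_model = LearnAPI.strip(model)
@assert predict(small_model, Point(), Xnew) == predict(model, Point(), Xnew)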

Learner traits

Learner traits record extra generic information about a learner, or make specific promises of behavior. They are methods that have a learner as the sole argument, and so we regard LearnAPI.constructor defined above as a trait.

Because we have implemented predict, we are required to overload the LearnAPI.kinds_of_proxy trait. Because we can only make point predictions of the target, we make this definition:

LearnAPI.kinds_of_proxy(::Ridge) = (Point(),)

A macro provides a shortcut, convenient when multiple traits are to be defined:

@trait(
    Ridge,
    constructor = Ridge,
    kinds_of_proxy = (Point(),),
    tags = ("regression",),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.clone),
        :(LearnAPI.strip),
        :(LearnAPI.obs),
        :(LearnAPI.features),
        :(LearnAPI.target),
        :(LearnAPI.predict),
        :(LearnAPI.coefficients),
   )
)

LearnAPI.functions (discussed further below) and LearnAPI.constructor are the only universally compulsory traits. However, it is worthwhile studying the list of all traits to see which might apply to a new implementation, to enable maximum buy-in to functionality provided by third party packages, and to assist third party algorithms that match machine learning algorithms to user-defined tasks.

With some exceptions, the value of a trait should depend only on the type of the argument.
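
For example, the trait just defined returns the same value for any two Ridge instances, whatever their hyperparameter values:

@assert LearnAPI.kinds_of_proxy(Ridge(0.1)) == LearnAPI.kinds_of_proxy(Ridge(10.0))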

The functions trait

The last trait above, functions, returns a list of all LearnAPI.jl methods that can be meaningfully applied to the learner or associated model, with the exception of traits. You always include the first five shown here: fit, learner, clone, strip, obs. Here clone is a utility function provided by LearnAPI that you never overload, while obs is discussed under Providing a separate data front end below and is always included because it has a meaningful fallback. The features method, here provided by a fallback, articulates how the features X can be extracted from the training data (X, y). We must also include target here to flag our model as supervised; again, the method itself is provided by a fallback valid in the present case.
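
To spell out what those fallbacks deliver in the present case (a sketch, assuming learner, X and y as in the demonstration below):

LearnAPI.features(learner, (X, y))  # returns `X`
LearnAPI.target(learner, (X, y))    # returns `y`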

See LearnAPI.functions for a checklist of what the functions trait needs to return.

Signatures added for convenience

We add one fit signature for user-convenience only. The LearnAPI.jl specification has nothing to say about fit signatures with more than two positional arguments.

LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)

Demonstration

We now illustrate how to interact directly with Ridge instances using the methods just implemented.

# synthesize some data:
n = 10 # number of observations
train = 1:6
test = 7:10
a, b, c = rand(n), rand(n), rand(n)
X = (; a, b, c)
y = 2a - b + 3c + 0.05*rand(n)
learner = Ridge(lambda=0.5)
@functions learner
(LearnAPI.fit, LearnAPI.learner, LearnAPI.clone, strip, LearnAPI.obs, LearnAPI.features, LearnAPI.target, LearnAPI.predict, LearnAPI.coefficients)

(Exact output may differ here because of the way documentation is generated.)

Training and predicting:

Xtrain = Tables.subset(X, train)
ytrain = y[train]
model = fit(learner, (Xtrain, ytrain))  # `fit(learner, Xtrain, ytrain)` will also work
ŷ = predict(model, Tables.subset(X, test))
4-element Vector{Float64}:
 1.7761358846796615
 1.4856015442605508
 1.4023808165187088
 1.8352601939368751

Extracting coefficients:

LearnAPI.coefficients(model)
3-element Vector{Pair{Symbol, Float64}}:
 :a => 1.3247718263388135
 :b => 0.281105660760313
 :c => 2.1562208226524455

Serialization/deserialization:

using Serialization
small_model = LearnAPI.strip(model)
filename = tempname()
serialize(filename, small_model)
recovered_model = deserialize(filename)
@assert LearnAPI.learner(recovered_model) == learner
@assert predict(recovered_model, X) == predict(model, X)

Testing an implementation

using LearnTestAPI
@testapi learner (X, y) verbosity=0

Other data patterns

Here are some important remarks for implementations deviating in their assumptions about data from those made above.

  • New implementations of fit, predict, etc, always have a single data argument as above. For convenience, a signature such as fit(learner, table, formula), calling fit(learner, (table, formula)), can be added, but the LearnAPI.jl specification is silent on the meaning or existence of signatures with extra arguments.

  • If the data object consumed by fit, predict, or transform is not a suitable table¹, array³, tuple of tables and arrays, or some other object implementing the MLCore.jl getobs/numobs interface, then an implementation must: (i) overload obs to articulate how provided data can be transformed into a form that does support this interface, as illustrated under Providing a separate data front end below; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API.

  • Where the form of data consumed by fit is different from that consumed by predict/transform (as in classical supervised learning) it may be necessary to explicitly overload the functions LearnAPI.features and (if supervised) LearnAPI.target. The same holds if overloading obs; see below.

Providing a separate data front end

See here for code without explanations.

An implementation may optionally implement obs, to expose to the user (or some meta-algorithm like cross-validation) the representation of input data internal to fit or predict, such as the matrix version A of X in the ridge example. That is, we may factor out of fit (and also predict) a data preprocessing step, obs, to expose its outcomes. These outcomes become alternative user inputs to fit/predict.

The obs methods exist to:

  • Enable meta-algorithms to avoid redundant conversions of user-provided data into the form ultimately used by the core training algorithms.

  • Through the provision of canned data front ends, enable users to provide data in a variety of formats, while allowing new implementations to focus on core algorithms that consume a standardized, preprocessed, representation of that data.

Important

While many new learner implementations will want to adopt a canned data front end, such as those provided by LearnDataFrontEnds.jl, we focus here on a self-contained implementation of obs for the ridge example above, to show how it works.

In the typical case, where LearnAPI.data_interface is not overloaded, the alternative data representations must implement the MLCore.jl getobs/numobs interface for observation subsampling, which is generally all a user or meta-algorithm will need, before passing the data on to fit/predict, as you would the original data.
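
For reference, here is how that interface behaves on an object that already supports it, such as a matrix whose last dimension indexes observations (a sketch, assuming MLCore is installed):

import MLCore
M = rand(3, 10)        # 10 observations, one per column
MLCore.numobs(M)       # 10
MLCore.getobs(M, 1:6)  # 3×6 matrix holding the first 6 observations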

So, instead of the pattern

model = fit(learner, data)
predict(model, newdata)

one enables the following alternative:

observations = obs(learner, data) # preprocessed training data

# optional subsampling:
observations = MLCore.getobs(observations, train_indices)

model = fit(learner, observations)

newobservations = obs(model, newdata)

# optional subsampling:
newobservations = MLCore.getobs(newobservations, test_indices)

predict(model, newobservations)

which works for any non-static learner implementing predict, no matter how one is supposed to access the individual observations of data or newdata. See also the demonstration below. Furthermore, fallbacks ensure the above pattern still works if we choose not to implement a front end at all, which is allowed, provided data and newdata already implement getobs/numobs.

Here we specifically wrap all the preprocessed data into a single object, for which we introduce a new type:

struct RidgeFitObs{T,M<:AbstractMatrix{T}}
    A::M                  # `p` x `n` matrix
    names::Vector{Symbol} # features
    y::Vector{T}          # target
end

Now we overload obs to carry out the data preprocessing previously in fit, like this:

function LearnAPI.obs(::Ridge, data)
    X, y = data
    table = Tables.columntable(X)
    names = Tables.columnnames(table) |> collect
    return RidgeFitObs(Tables.matrix(table)', names, y)
end

We informally refer to the output of obs as "observations" (see The obs contract below). The previous core fit signature is now replaced with two methods - one to handle "regular" input, and one to handle the pre-processed data (observations) which appears first below:

function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)

    lambda = learner.lambda

    A = observations.A
    names = observations.names
    y = observations.y

    # apply core learner:
    coefficients = (A*A' + learner.lambda*I)\(A*y) # vector

    # determine named coefficients:
    named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]

    # make some noise, if allowed:
    verbosity > 0 && @info "Coefficients: $named_coefficients"

    return RidgeFitted(learner, coefficients, named_coefficients)

end

LearnAPI.fit(learner::Ridge, data; kwargs...) =
    fit(learner, obs(learner, data); kwargs...)

The obs contract

Providing fit signatures matching the output of obs is the first part of the obs contract. Since obs(learner, data) should evidently support all data that fit(learner, data) supports, we must be able to apply obs(learner, _) to its own output (observations below). This leads to the additional declaration

LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations

In other words, we ensure that obs(learner, _) is idempotent.
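
A quick check of this property (a sketch, assuming learner, X and y from the demonstration above):

observations = obs(learner, (X, y))
@assert obs(learner, observations) === observations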

The second part of the obs contract is this: The output of obs must implement the interface specified by the trait LearnAPI.data_interface(learner). Assuming this is LearnAPI.RandomAccess() (the default) it usually suffices to overload Base.getindex and Base.length:

Base.getindex(data::RidgeFitObs, I) =
    RidgeFitObs(data.A[:,I], data.names, data.y[I])
Base.length(data::RidgeFitObs) = length(data.y)

We do something similar for predict, but there's no need for a new type in this case:

LearnAPI.obs(::RidgeFitted, Xnew) = Tables.matrix(Xnew)'
LearnAPI.obs(::RidgeFitted, observations::AbstractArray) = observations # idempotency

LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
    observations'*model.coefficients

LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
    predict(model, Point(), obs(model, Xnew))

features and target methods

Two methods LearnAPI.features and LearnAPI.target articulate how features and target can be extracted from data consumed by LearnAPI.jl methods. Fallbacks provided by LearnAPI.jl sufficed in our basic implementation above. Here we must explicitly overload them, so that they also handle the output of obs(learner, data):

LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, obs(learner, data))
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
LearnAPI.target(learner::Ridge, data) = LearnAPI.target(learner, obs(learner, data))
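
As a quick sanity check of these overloads (a sketch, again assuming learner, X and y from the demonstration above):

observations = obs(learner, (X, y))
@assert LearnAPI.target(learner, observations) == y
@assert LearnAPI.target(learner, (X, y)) == y  # the second method delegates to the first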

Important notes:

  • The observations to be consumed by fit are returned by obs(learner::Ridge, ...), while those consumed by predict are returned by obs(model::RidgeFitted, ...). We need the different signatures because the form of data consumed by fit and predict are generally different.

  • We need the adjoint operator, ', because the last dimension in arrays is the observation dimension, according to the MLCore.jl convention. Remember, Xnew is a table here.

Since LearnAPI.jl provides fallbacks for obs that simply return the unadulterated data argument, overloading obs is optional. This is provided the data appearing in publicized fit/predict signatures already consists only of objects implementing the LearnAPI.RandomAccess interface (most tables¹, arrays³, and tuples thereof).

To opt out of supporting the MLCore.jl interface altogether, an implementation must overload the trait, LearnAPI.data_interface(learner). See Data interfaces for details.
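
For example, a hypothetical learner type SomeIterableLearner, whose preprocessed data only supports iteration, might make a declaration along the following lines (a sketch only; we assume here that LearnAPI.FiniteIterable is among the interface types documented under Data interfaces, and note it would not be appropriate for the ridge example):

LearnAPI.data_interface(::SomeIterableLearner) = LearnAPI.FiniteIterable()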

Addition of signatures for user convenience

As above, we add a signature for convenience, which the LearnAPI.jl specification neither requires nor forbids:

LearnAPI.fit(learner::Ridge, X, y; kwargs...)  = fit(learner, (X, y); kwargs...)

Demonstration of an advanced obs workflow

We can now train and predict using internal data representations, resampled using the generic MLCore.jl interface:

import MLCore
learner = Ridge(lambda=0.5) # same learner as in the earlier demonstration
observations_for_fit = obs(learner, (X, y))
model = fit(learner, MLCore.getobs(observations_for_fit, train))
observations_for_predict = obs(model, X)
ẑ = predict(model, MLCore.getobs(observations_for_predict, test))
4-element Vector{Float64}:
 2.4123196959466293
 3.605206762304774
 3.4271925316373433
 0.8368147321800036
@assert ẑ == ŷ

For an application of obs to efficient cross-validation, see here.


¹ In LearnAPI.jl a table is any object X implementing the Tables.jl interface, additionally satisfying Tables.istable(X) == true and implementing DataAPI.nrow (and whence MLCore.numobs). Tables that are also (unnamed) tuples are disallowed.

² An implementation can provide further accessor functions, if necessary, but like the native ones, they must be included in the LearnAPI.functions declaration.

³ The last index must be the observation index.

⁴ The data = (X, y) pattern implemented here is not the only supported pattern. For example, data might be (T, formula) where T is a table and formula is an R-style formula.