Anatomy of an Implementation
The core LearnAPI.jl pattern looks like this:
model = fit(learner, data)
predict(model, newdata)
Here learner
specifies hyperparameters, while model
stores learned parameters and any byproducts of algorithm execution.
Variations on this pattern:
- Transformers ordinarily implement transform instead of predict. For more on predict versus transform, see Predict or transform?
- "Static" (non-generalizing) algorithms, which include some simple transformers and some clustering algorithms, have a fit that consumes no data. Instead, predict or transform does the heavy lifting.
- In density estimation, the newdata argument in predict is missing.
These are the basic possibilities.
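For example, the corresponding workflow for a (non-static) transformer, with scaler a hypothetical transformer-learner, would look like this:
model = fit(scaler, X)
Xnew_transformed = transform(model, Xnew)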
Elaborating on the core pattern above, this tutorial details an implementation of LearnAPI.jl for naive ridge regression with no intercept. The kind of workflow we want to enable has been previewed in Sample workflow. Readers can also refer to the demonstration of the implementation given later.
A basic implementation
See here for code without explanations.
We suppose our algorithm's fit
method consumes data in the form (X, y)
, where X
is a suitable table¹ (the features) and y
a vector (the target).
Implementations wishing to support other data patterns may need to take additional steps explained under Other data patterns below.
The first line below imports the lightweight package LearnAPI.jl whose methods we will be extending. The second imports libraries needed for the core algorithm.
using LearnAPI
using LinearAlgebra, Tables
Defining learners
Here's a new type whose instances specify the single ridge regression hyperparameter:
struct Ridge{T<:Real}
lambda::T
end
Instances of Ridge
are learners, in LearnAPI.jl parlance.
Associated with each new type of LearnAPI.jl learner will be a keyword argument constructor, providing default values for all properties (typically, struct fields) that are not other learners, and we must implement LearnAPI.constructor(learner)
, for recovering the constructor from an instance:
"""
Ridge(; lambda=0.1)
Instantiate a ridge regression learner, with regularization of `lambda`.
"""
Ridge(; lambda=0.1) = Ridge(lambda)
LearnAPI.constructor(::Ridge) = Ridge
For example, in this case, if learner = Ridge(0.2)
, then LearnAPI.constructor(learner)(lambda=0.2) == learner
is true. Note that we attach the docstring to the constructor, not the struct.
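As a quick check, given the definitions above, the following passes:
learner = Ridge(0.2)
@assert LearnAPI.constructor(learner)(lambda=0.2) == learner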
Implementing fit
A ridge regressor requires two types of data for training: input features X
, which here we suppose are tabular¹, and a target y
, which we suppose is a vector.⁴
It is convenient to define a new type for the fit
output, which will include coefficients labelled by feature name for inspection after training:
struct RidgeFitted{T,F}
learner::Ridge
coefficients::Vector{T}
named_coefficients::F
end
Note that we also include learner
in the struct, for it must be possible to recover learner
from the output of fit
; see Accessor functions below.
The implementation of fit
looks like this:
function LearnAPI.fit(learner::Ridge, data; verbosity=1)
X, y = data
# data preprocessing:
table = Tables.columntable(X)
names = Tables.columnnames(table) |> collect
A = Tables.matrix(table, transpose=true)
lambda = learner.lambda
# apply core algorithm:
coefficients = (A*A' + lambda*I)\(A*y) # vector
# determine named coefficients:
named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]
# make some noise, if allowed:
verbosity > 0 && @info "Coefficients: $named_coefficients"
return RidgeFitted(learner, coefficients, named_coefficients)
end
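To spell out the core computation: with A the p × n matrix whose columns are the training observations and λ the regularization parameter, the backslash line above solves the regularized normal equations (A*A' + λI)*coefficients = A*y, giving the familiar ridge estimate (XᵀX + λI)⁻¹Xᵀy for a linear model with no intercept (here X = A' is the n × p design matrix).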
Implementing predict
One way users will be able to call predict
is like this:
predict(model, Point(), Xnew)
where Xnew
is a table (of the same form as X
above). The argument Point()
signals that literal predictions of the target variable are sought, as opposed to some proxy for the target, such as probability density functions. Point
is an example of a LearnAPI.KindOfProxy
type. Targets and target proxies are discussed here.
We provide this implementation for our ridge regressor:
LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
Tables.matrix(Xnew)*model.coefficients
If the kind of proxy is omitted, as in predict(model, Xnew)
, then a fallback grabs the first element of the tuple returned by LearnAPI.kinds_of_proxy(learner)
, which we overload appropriately below.
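In particular, once that trait is overloaded as below, these two calls return identical predictions:
predict(model, Xnew)
predict(model, Point(), Xnew)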
Accessor functions
An accessor function has the output of fit as its sole argument. Every new implementation must implement the accessor function LearnAPI.learner for recovering a learner from a fitted object:
LearnAPI.learner(model::RidgeFitted) = model.learner
Other accessor functions extract learned parameters or some standard byproducts of training, such as feature importances or training losses.² Here we implement an accessor function to extract the linear coefficients:
LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
The LearnAPI.strip(model) accessor function is for returning a version of model suitable for serialization (typically smaller, and with data anonymized). It has a fallback that just returns model, but for the sake of illustration, we overload it to dump the named version of the coefficients:
LearnAPI.strip(model::RidgeFitted) =
RidgeFitted(model.learner, model.coefficients, nothing)
Crucially, we can still use LearnAPI.strip(model)
in place of model
to make new predictions.
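For example, assuming model and Xnew are as above, the following check passes, because predict only needs the unnamed coefficients:
@assert predict(LearnAPI.strip(model), Point(), Xnew) == predict(model, Point(), Xnew)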
Learner traits
Learner traits record extra generic information about a learner, or make specific promises of behavior. They are methods that have a learner as the sole argument, and so we regard LearnAPI.constructor
defined above as a trait.
Because we have implemented predict
, we are required to overload the LearnAPI.kinds_of_proxy
trait. Because we can only make point predictions of the target, we make this definition:
LearnAPI.kinds_of_proxy(::Ridge) = (Point(),)
A macro provides a shortcut, convenient when multiple traits are to be defined:
@trait(
Ridge,
constructor = Ridge,
kinds_of_proxy=(Point(),),
tags = ("regression",),
functions = (
:(LearnAPI.fit),
:(LearnAPI.learner),
:(LearnAPI.clone),
:(LearnAPI.strip),
:(LearnAPI.obs),
:(LearnAPI.features),
:(LearnAPI.target),
:(LearnAPI.predict),
:(LearnAPI.coefficients),
)
)
LearnAPI.functions (discussed further below) and LearnAPI.constructor are the only universally compulsory traits. However, it is worthwhile studying the list of all traits to see which might apply to a new implementation, to enable maximum buy-in to functionality provided by third party packages, and to assist third party algorithms that match machine learning algorithms to user-defined tasks.
With some exceptions, the value of a trait should depend only on the type of the argument.
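For example, under the definitions above, Ridge instances with different hyperparameter values report identical trait values:
@assert LearnAPI.kinds_of_proxy(Ridge(lambda=0.1)) == LearnAPI.kinds_of_proxy(Ridge(lambda=100))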
The functions trait
The last trait above, functions, returns a list of all LearnAPI.jl methods that can be meaningfully applied to the learner or associated model, with the exception of traits. You always include the first five you see here: fit, learner, clone, strip, obs. Here clone is a utility function provided by LearnAPI that you never overload, while obs is discussed under Providing a separate data front end below and is always included because it has a meaningful fallback. The features method, here provided by a fallback, articulates how the features X can be extracted from the training data (X, y). We must also include target here to flag our model as supervised; again, the method itself is provided by a fallback valid in the present case.
See LearnAPI.functions
for a checklist of what the functions
trait needs to return.
Signatures added for convenience
We add one fit signature for user convenience only. The LearnAPI.jl specification has nothing to say about fit signatures with more than two positional arguments.
LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)
Demonstration
We now illustrate how to interact directly with Ridge
instances using the methods just implemented.
# synthesize some data:
n = 10 # number of observations
train = 1:6
test = 7:10
a, b, c = rand(n), rand(n), rand(n)
X = (; a, b, c)
y = 2a - b + 3c + 0.05*rand(n)
learner = Ridge(lambda=0.5)
@functions learner
(LearnAPI.fit, LearnAPI.learner, LearnAPI.clone, strip, LearnAPI.obs, LearnAPI.features, LearnAPI.target, LearnAPI.predict, LearnAPI.coefficients)
(Exact output may differ here because of the way the documentation is generated.)
Training and predicting:
Xtrain = Tables.subset(X, train)
ytrain = y[train]
model = fit(learner, (Xtrain, ytrain)) # `fit(learner, Xtrain, ytrain)` will also work
ŷ = predict(model, Tables.subset(X, test))
4-element Vector{Float64}:
1.7761358846796615
1.4856015442605508
1.4023808165187088
1.8352601939368751
Extracting coefficients:
LearnAPI.coefficients(model)
3-element Vector{Pair{Symbol, Float64}}:
:a => 1.3247718263388135
:b => 0.281105660760313
:c => 2.1562208226524455
Serialization/deserialization:
using Serialization
small_model = LearnAPI.strip(model)
filename = tempname()
serialize(filename, small_model)
recovered_model = deserialize(filename)
@assert LearnAPI.learner(recovered_model) == learner
@assert predict(recovered_model, X) == predict(model, X)
Testing an implementation
using LearnTestAPI
@testapi learner (X, y) verbosity=0
Other data patterns
Here are some important remarks for implementations deviating in their assumptions about data from those made above.
- New implementations of fit, predict, etc., always have a single data argument, as above. For convenience, a signature such as fit(learner, table, formula), calling fit(learner, (table, formula)), can be added (as sketched after this list), but the LearnAPI.jl specification is silent on the meaning or existence of signatures with extra arguments.
- If the data object consumed by fit, predict, or transform is not a suitable table¹, array³, tuple of tables and arrays, or some other object implementing the MLCore.jl getobs/numobs interface, then an implementation must: (i) overload obs to articulate how provided data can be transformed into a form that does support this interface, as illustrated under Providing a separate data front end below; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API.
- Where the form of data consumed by fit is different from that consumed by predict/transform (as in classical supervised learning), it may be necessary to explicitly overload the functions LearnAPI.features and (if supervised) LearnAPI.target. The same holds if overloading obs; see below.
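For instance, a learner accepting the table-plus-formula pattern mentioned in the first point could add a convenience signature along these lines (a sketch only; MyLearner is a hypothetical learner type):
LearnAPI.fit(learner::MyLearner, table, formula; kwargs...) =
fit(learner, (table, formula); kwargs...)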
Providing a separate data front end
See here for code without explanations.
An implementation may optionally implement obs
, to expose to the user (or some meta-algorithm like cross-validation) the representation of input data internal to fit
or predict
, such as the matrix version A
of X
in the ridge example. That is, we may factor out of fit
(and also predict
) a data preprocessing step, obs
, to expose its outcomes. These outcomes become alternative user inputs to fit
/predict
.
The obs
methods exist to:
- Enable meta-algorithms to avoid redundant conversions of user-provided data into the form ultimately used by the core training algorithms.
- Through the provision of canned data front ends, enable users to provide data in a variety of formats, while allowing new implementations to focus on core algorithms that consume a standardized, preprocessed representation of that data.
While many new learner implementations will want to adopt a canned data front end, such as those provided by LearnDataFrontEnds.jl, we focus here on a self-contained implementation of obs
for the ridge example above, to show how it works.
In the typical case, where LearnAPI.data_interface
is not overloaded, the alternative data representations must implement the MLCore.jl getobs/numobs
interface for observation subsampling, which is generally all a user or meta-algorithm will need, before passing the data on to fit
/predict
, as you would the original data.
So, instead of the pattern
model = fit(learner, data)
predict(model, newdata)
one enables the following alternative:
observations = obs(learner, data) # preprocessed training data
# optional subsampling:
observations = MLCore.getobs(observations, train_indices)
model = fit(learner, observations)
newobservations = obs(model, newdata)
# optional subsampling:
newobservations = MLCore.getobs(newobservations, test_indices)
predict(model, newobservations)
which works for any non-static learner implementing predict, no matter how one is supposed to access the individual observations of data or newdata. See also the demonstration below. Furthermore, fallbacks ensure the above pattern still works if we choose not to implement a front end at all, which is allowed if the supported data and newdata already implement getobs/numobs.
Here we specifically wrap all the preprocessed data into a single object, for which we introduce a new type:
struct RidgeFitObs{T,M<:AbstractMatrix{T}}
A::M # `p` x `n` matrix
names::Vector{Symbol} # features
y::Vector{T} # target
end
Now we overload obs
to carry out the data preprocessing previously in fit
, like this:
function LearnAPI.obs(::Ridge, data)
X, y = data
table = Tables.columntable(X)
names = Tables.columnnames(table) |> collect
return RidgeFitObs(Tables.matrix(table)', names, y)
end
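For example, with X and y as synthesized in the demonstration above, we would have:
observations = obs(Ridge(), (X, y))
observations.names # [:a, :b, :c]
observations.A # 3 × 10 matrix whose columns are the observations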
We informally refer to the output of obs
as "observations" (see The obs
contract below). The previous core fit
signature is now replaced with two methods - one to handle "regular" input, and one to handle the pre-processed data (observations) which appears first below:
function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)
lambda = learner.lambda
A = observations.A
names = observations.names
y = observations.y
# apply core learner:
coefficients = (A*A' + lambda*I)\(A*y) # vector
# determine named coefficients:
named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]
# make some noise, if allowed:
verbosity > 0 && @info "Coefficients: $named_coefficients"
return RidgeFitted(learner, coefficients, named_coefficients)
end
LearnAPI.fit(learner::Ridge, data; kwargs...) =
fit(learner, obs(learner, data); kwargs...)
The obs contract
Providing fit signatures matching the output of obs is the first part of the obs contract. Since obs(learner, data) should evidently support all data that fit(learner, data) supports, we must be able to apply obs(learner, _) to its own output (observations below). This leads to the additional declaration
LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations
In other words, we ensure that obs(learner, _)
is involutive.
The second part of the obs
contract is this: The output of obs
must implement the interface specified by the trait LearnAPI.data_interface(learner)
. Assuming this is LearnAPI.RandomAccess()
(the default) it usually suffices to overload Base.getindex
and Base.length
:
Base.getindex(data::RidgeFitObs, I) =
RidgeFitObs(data.A[:,I], data.names, data.y[I])
Base.length(data::RidgeFitObs) = length(data.y)
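With these definitions in place, generic MLCore.jl subsampling works on RidgeFitObs objects, as these methods fall back to getindex and length here. For example, continuing with the data synthesized in the demonstration above:
import MLCore
observations = obs(learner, (X, y))
MLCore.numobs(observations) # 10
MLCore.getobs(observations, train) # a RidgeFitObs containing 6 observations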
We do something similar for predict
, but there's no need for a new type in this case:
LearnAPI.obs(::RidgeFitted, Xnew) = Tables.matrix(Xnew)'
LearnAPI.obs(::RidgeFitted, observations::AbstractArray) = observations # involutivity
LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
observations'*model.coefficients
LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
predict(model, Point(), obs(model, Xnew))
features and target methods
Two methods LearnAPI.features
and LearnAPI.target
articulate how features and target can be extracted from data
consumed by LearnAPI.jl methods. Fallbacks provided by LearnAPI.jl sufficed in our basic implementation above. Here we must explicitly overload them, so that they also handle the output of obs(learner, data)
:
LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, obs(learner, data))
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
LearnAPI.target(learner::Ridge, data) = LearnAPI.target(learner, obs(learner, data))
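For example, with learner, X, and y as in the demonstration above, both of the following would hold:
@assert LearnAPI.target(learner, (X, y)) == y
@assert LearnAPI.features(learner, (X, y)) == obs(learner, (X, y)).A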
Important notes:
- The observations to be consumed by fit are returned by obs(learner::Ridge, ...), while those consumed by predict are returned by obs(model::RidgeFitted, ...). We need the different signatures because the form of data consumed by fit and predict is generally different.
- We need the adjoint operator, ', because the last dimension in arrays is the observation dimension, according to the MLCore.jl convention. Remember, Xnew is a table here.
Since LearnAPI.jl provides fallbacks for obs that simply return the unadulterated data argument, overloading obs is optional. This is provided the data in publicized fit/predict signatures already consists only of objects implementing the LearnAPI.RandomAccess interface (most tables¹, arrays³, and tuples thereof).
To opt out of supporting the MLCore.jl interface altogether, an implementation must overload the trait, LearnAPI.data_interface(learner)
. See Data interfaces for details.
Addition of signatures for user convenience
As above, we add a signature for convenience, which the LearnAPI.jl specification neither requires nor forbids:
LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)
Demonstration of an advanced obs workflow
We now can train and predict using internal data representations, resampled using the generic MLCore.jl interface:
import MLCore
learner = Ridge()
observations_for_fit = obs(learner, (X, y))
model = fit(learner, MLCore.getobs(observations_for_fit, train))
observations_for_predict = obs(model, X)
ẑ = predict(model, MLCore.getobs(observations_for_predict, test))
4-element Vector{Float64}:
2.4123196959466293
3.605206762304774
3.4271925316373433
0.8368147321800036
@assert ẑ == ŷ
For an application of obs
to efficient cross-validation, see here.
¹ In LearnAPI.jl a table is any object X
implementing the Tables.jl interface, additionally satisfying Tables.istable(X) == true
and implementing DataAPI.nrow
(and whence MLCore.numobs
). Tables that are also (unnamed) tuples are disallowed.
² An implementation can provide further accessor functions, if necessary, but like the native ones, they must be included in the LearnAPI.functions
declaration.
³ The last index must be the observation index.
⁴ The data = (X, y) pattern implemented here is not the only supported pattern. For example, data might be (T, formula), where T is a table and formula is an R-style formula.