Anatomy of an Implementation
Summary. Formally, an algorithm is a container for the hyperparameters of some ML/statistics algorithm. A basic implementation of the ridge regressor requires implementing fit and predict methods dispatched on the algorithm type; predict is an example of an operation, the others being transform and inverse_transform. In this example we also implement an accessor function, called feature_importances, returning the absolute values of the linear coefficients. The ridge regressor has a target variable and outputs literal predictions of the target (rather than, say, probabilistic predictions); accordingly, the overloaded predict method is dispatched on the LiteralTarget subtype of KindOfProxy. An algorithm trait declares this type as the preferred kind of target proxy. Other traits articulate the algorithm's training data type requirements and the input/output types of predict.
We begin by describing an implementation of LearnAPI.jl for basic ridge regression (without intercept) to introduce the main actors in any implementation.
Defining an algorithm type
The first line below imports the lightweight package LearnAPI.jl, whose methods we will be extending; the second imports libraries needed for the core algorithm.
using LearnAPI
using LinearAlgebra, Tables
Next, we define a struct to store the single hyperparameter lambda of this algorithm:
struct MyRidge <: LearnAPI.Algorithm
    lambda::Float64
end
The subtyping MyRidge <: LearnAPI.Algorithm is optional but recommended where it is not otherwise disruptive. Instances of MyRidge are called algorithms, and MyRidge is an algorithm type.
A keyword argument constructor providing defaults for all hyperparameters should be provided:
MyRidge(; lambda=0.1) = MyRidge(lambda)
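As a quick sanity check (using only the constructor just defined):

@assert MyRidge().lambda == 0.1
@assert MyRidge(lambda=0.05).lambda == 0.05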
Implementing training (fit)
A ridge regressor requires two types of data for training: input features X and a target y. Training is implemented by overloading fit. Here verbosity is an integer (0 should train silently, unless warnings are needed):
function LearnAPI.fit(algorithm::MyRidge, verbosity, X, y)

    # process input:
    x = Tables.matrix(X)  # convert table to matrix
    s = Tables.schema(X)
    features = s.names

    # core solver:
    coefficients = (x'x + algorithm.lambda*I)\(x'y)

    # prepare output - learned parameters:
    fitted_params = (; coefficients)

    # prepare output - algorithm state:
    state = nothing  # not relevant here

    # prepare output - byproducts of training:
    feature_importances =
        [features[j] => abs(coefficients[j]) for j in eachindex(features)]
    sort!(feature_importances, by=last, rev=true)
    verbosity > 0 && @info "Features in order of importance: $(first.(feature_importances))"
    report = (; feature_importances)

    return fitted_params, state, report
end
Regarding the return value of fit:

- The fitted_params variable is for the algorithm's learned parameters, for passing to predict (see below).
- The state variable is only relevant when additionally implementing a LearnAPI.update! or LearnAPI.ingest! method (see Fit, update! and ingest!).
- The report is for other byproducts of training, apart from the learned parameters (the ones we'll need to provide to predict below).
Our fit method assumes that X is a table (satisfies the Tables.jl spec) whose rows are the observations, and it will need y to be an AbstractFloat vector. An algorithm implementation is free to dictate the representation of data that fit accepts, but articulates its requirements using appropriate traits; see Training data types below. We recommend against data type checks internal to fit; this would ordinarily be the responsibility of a higher-level API, using those traits.
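For example, a named tuple of equal-length vectors satisfies the Tables.jl spec, so data of the following form meets the assumptions just described (the column names here are arbitrary):

X = (x1 = [1.0, 2.0, 3.0], x2 = [4.0, 5.0, 6.0])  # a column table; rows are observations
y = [0.3, 0.1, 0.4]                               # an AbstractFloat vector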
Operations
Now we need a method for predicting the target on new input features:
function LearnAPI.predict(::MyRidge, ::LearnAPI.LiteralTarget, fitted_params, Xnew)
    Xmatrix = Tables.matrix(Xnew)
    report = nothing
    return Xmatrix*fitted_params.coefficients, report
end
The second argument of predict is always an instance of KindOfProxy, and will always be LiteralTarget() in this case, as only literal values of the target (rather than, say, probabilistic predictions) are being supported.
In some algorithms predict computes something of interest in addition to the target prediction, and this report item is returned as the second component of the return value. When there's nothing to report, we must return nothing, as here.
Our predict method is an example of an operation. Other operations include transform and inverse_transform, and an algorithm can implement more than one. For example, a K-means clustering algorithm might implement transform for dimension reduction, and predict to return cluster labels.
The predict method is reserved for predictions of a target variable, and only predict has the extra ::KindOfProxy argument.
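For contrast with predict, here is a sketch of how a toy data-centering algorithm might implement the other two operations. The type MyCenterer is our invention, and we assume operations other than predict also return an (output, report) pair, as predict does above:

struct MyCenterer <: LearnAPI.Algorithm end

# training learns the column means (re-using the LearnAPI and Tables imports above):
function LearnAPI.fit(::MyCenterer, verbosity, X)
    Xmatrix = Tables.matrix(X)
    fitted_params = (; means=sum(Xmatrix, dims=1)/size(Xmatrix, 1))
    return fitted_params, nothing, nothing
end

# note the absence of a ::KindOfProxy argument:
LearnAPI.transform(::MyCenterer, fitted_params, Xnew) =
    (Tables.matrix(Xnew) .- fitted_params.means, nothing)

LearnAPI.inverse_transform(::MyCenterer, fitted_params, W) =
    (W .+ fitted_params.means, nothing)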
Accessor functions
The arguments of an operation are always (algorithm, fitted_params, data...). The interface also provides accessor functions for extracting, from the fitted_params and/or fit report, information shared by several algorithm types. There is one for feature importances that we can implement for MyRidge:
LearnAPI.feature_importances(::MyRidge, fitted_params, report) =
    report.feature_importances
Another example of an accessor function is LearnAPI.training_losses.
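By way of illustration only, an iterative algorithm whose fit recorded per-iteration losses in its report might overload it along these lines (MyIterativeModel and its losses report field are hypothetical, and we assume the signature matches that of feature_importances above):

struct MyIterativeModel <: LearnAPI.Algorithm end  # hypothetical

# assumes `fit` for this type returns a report with a `losses` field:
LearnAPI.training_losses(::MyIterativeModel, fitted_params, report) =
    report.losses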
Algorithm traits
We have implemented predict, and it is possible to implement predict methods for multiple KindOfProxy types (see Target proxies for a complete list). Accordingly, we are required to declare a preferred target proxy, which we do using LearnAPI.preferred_kind_of_proxy:
LearnAPI.preferred_kind_of_proxy(::MyRidge) = LearnAPI.LiteralTarget()
Or, you can use the shorthand
@trait MyRidge preferred_kind_of_proxy=LearnAPI.LiteralTarget()
LearnAPI.preferred_kind_of_proxy is an example of an algorithm trait. A complete list of traits and the contracts they imply is given in Algorithm Traits.
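Either form defines an ordinary Julia method, so the trait can be queried on any instance:

LearnAPI.preferred_kind_of_proxy(MyRidge())  # LearnAPI.LiteralTarget()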
We also need to indicate that a target variable appears in training (this is a supervised algorithm). We do this by declaring where in the list of training data arguments (in this case (X, y)) the target variable (in this case y) appears:
@trait MyRidge position_of_target=2
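A higher-level API might use this declaration to locate the target in heterogeneous training data, along the lines of the following sketch (the helper target is hypothetical, not part of LearnAPI.jl):

# hypothetical helper, not part of LearnAPI.jl:
target(algorithm, data) = data[LearnAPI.position_of_target(algorithm)]

Xdemo = (x1 = rand(3),)
ydemo = rand(3)
@assert target(MyRidge(), (Xdemo, ydemo)) === ydemo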
As explained in the introduction, LearnAPI.jl does not attempt to define strict algorithm categories, such as "regression" or "clustering". However, we can optionally specify suggestive descriptors, as in
@trait MyRidge descriptors=(:regression,)
This declaration actually promises nothing, but can help in generating documentation. Do LearnAPI.descriptors() to get a list of available descriptors.
Finally, we are required to declare what methods (excluding traits) we have explicitly overloaded for our type:
@trait MyRidge methods=(
    :fit,
    :predict,
    :feature_importances,
)
Training data types
Since LearnAPI.jl is a basement-level API, one is discouraged from including explicit type checks in an implementation of fit. Instead one uses traits to make promises about the acceptable type of data consumed by fit. In general, this can be a promise regarding the ordinary type of data or the scientific type of data (but not both). Alternatively, one may only promise a bound on the type/scitype of observations in the data. See Algorithm Traits for further details. In this case we'll be happy to restrict the scitype of the data:
import ScientificTypesBase: scitype, Table, Continuous
@trait MyRidge fit_scitype = Tuple{Table(Continuous), AbstractVector{Continuous}}
This is a contract that data is acceptable in the call fit(algorithm, verbosity, data...) whenever
scitype(data) <: Tuple{Table(Continuous), AbstractVector{Continuous}}
Or, in other words:

- X in fit(algorithm, verbosity, X, y) is acceptable, provided scitype(X) <: Table(Continuous), meaning that X satisfies Tables.istable(X) == true (see Tables.jl) and each column has some <:AbstractFloat element type.
- y in fit(algorithm, verbosity, X, y) is acceptable if scitype(y) <: AbstractVector{Continuous}, meaning that it is an abstract vector with <:AbstractFloat elements.
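Assuming ScientificTypes.jl is loaded (the package implementing the scitype convention assumed here), the contract can be verified on concrete data, querying the trait like any other; this is a sketch, not part of the required implementation:

using ScientificTypes

X = (x1 = rand(5), x2 = rand(5))  # table with Continuous columns
y = rand(5)                       # AbstractVector{Continuous}
@assert scitype((X, y)) <: LearnAPI.fit_scitype(MyRidge())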
Input types for operations
An optional promise about what data is guaranteed to work in a call like predict(algorithm, fitted_params, data...) is articulated this way:
@trait MyRidge predict_input_scitype = Tuple{AbstractVector{<:Continuous}}
Note that data is always a Tuple, even if it has only one component (the typical case), which explains the Tuple on the right-hand side.
Optionally, we may express our promise using regular types instead, via the LearnAPI.predict_input_type trait.
One can optionally make promises about the output of an operation. See Algorithm Traits for details.
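For instance, a declaration along the following lines might promise continuous vector predictions; the trait name predict_output_scitype is our assumption here, so consult Algorithm Traits for the actual name and value convention:

# hypothetical declaration; trait name and value convention assumed:
@trait MyRidge predict_output_scitype = AbstractVector{Continuous}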
Illustrative fit/predict workflow
We now illustrate how to interact directly with MyRidge instances using the methods we have implemented. Here's some toy data for supervised learning:
using Tables
n = 10 # number of training observations
train = 1:6
test = 7:10
a, b, c = rand(n), rand(n), rand(n)
X = (; a, b, c) |> Tables.rowtable
y = 2a - b + 3c + 0.05*rand(n)
Instantiate an algorithm with relevant hyperparameters (which is all the object stores):
algorithm = MyRidge(lambda=0.5)
Main.MyRidge(0.5)
Train the algorithm (the 0 means do so silently):
import LearnAPI: fit, predict, feature_importances
fitted_params, state, fit_report = fit(algorithm, 0, X[train], y[train])
((coefficients = [1.7193532777568572, -0.0188372523133215, 1.7544276092776423],), nothing, (feature_importances = [:c => 1.7544276092776423, :a => 1.7193532777568572, :b => 0.0188372523133215],))
Inspect the learned parameters and report:
@info "training outcomes" fitted_params fit_report
┌ Info: training outcomes
│ fitted_params = (coefficients = [1.7193532777568572, -0.0188372523133215, 1.7544276092776423],)
└ fit_report = (feature_importances = [:c => 1.7544276092776423, :a => 1.7193532777568572, :b => 0.0188372523133215],)
Inspect feature importances:
feature_importances(algorithm, fitted_params, fit_report)
3-element Vector{Pair{Symbol, Float64}}:
:c => 1.7544276092776423
:a => 1.7193532777568572
:b => 0.0188372523133215
Make a prediction using new data:
yhat, predict_report = predict(algorithm, LearnAPI.LiteralTarget(), fitted_params, X[test])
([1.9250853557685823, 1.6835960082978159, 1.4790364479108098, 2.378366395814834], nothing)
Compare predictions with the ground truth:
deviations = yhat - y[test]
loss = deviations .^2 |> sum
@info "Sum of squares loss" loss
┌ Info: Sum of squares loss
└ loss = 1.4475141172409245