Testing an Implementation
Testing is provided by the LearnTestAPI.jl package documented below.
Quick start
LearnTestAPI — Module
Module for testing implementations of the interface defined in LearnAPI.jl. If your package defines an object `learner` implementing the interface, then put something like this in your test suite:
using LearnTestAPI
# create some test data:
X = ...
y = ...
data = (X, y)
# bump verbosity to debug:
@testapi learner data verbosity=1
Once tests pass, set `verbosity=0` to suppress the detailed logging.
For details and options, see `LearnTestAPI.@testapi`.
New releases of LearnTestAPI.jl may add tests to `@testapi`, and this may result in new failures in client package test suites, because of previously undetected broken contracts. Adding a test to `@testapi` is not considered a breaking change to LearnTestAPI.jl, unless it supports a breaking change to LearnAPI.jl.
The @testapi macro
LearnTestAPI.@testapi — Macro

@testapi learner dataset1 dataset2 ... verbosity=1
Test that `learner` correctly implements the LearnAPI.jl interface, by checking contracts against one or more data sets.
using LearnTestAPI
X = (
feature1 = [1, 2, 3],
feature2 = ["a", "b", "c"],
feature3 = [10.0, 20.0, 30.0],
)
@testapi MyFeatureSelector(; features=[:feature3,]) X verbosity=1
Extended help
Assumptions
In some tests, strict `==` is enforced on the output of `predict` or `transform`, unless `isapprox` is also implemented. If `predict` outputs categorical vectors, for example, then requiring `==` in a test is appropriate. On the other hand, if `predict` outputs some abstract vector of eltype `Float32`, then `isapprox` will need to be implemented for that vector type, because the strict test `==` is likely to fail. The same applies to more complicated objects, such as probability distributions or sampleable objects: if `==` is likely to fail in "benign" cases, be sure `isapprox` is implemented. See `LearnTestAPI.isnear` for the exact test applied.
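For instance, the following snippet (an illustration only, not part of the test suite) shows why strict `==` can fail in such "benign" cases while `isapprox` succeeds:

yhat = Float32[0.1, 0.2, 0.3]
yhat_recomputed = yhat .+ eps(Float32)   # differs only by rounding-level noise

yhat == yhat_recomputed          # false: strict equality fails
isapprox(yhat, yhat_recomputed)  # true: the approximate comparison succeeds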
What is not tested?
When `verbosity=1` (the default), the test log describes all contracts tested.

The following are not tested:
- That the output of `LearnAPI.target(learner, data)` is indeed a target, in the sense that it can be paired, in some way, with the output of `predict`. Such a test would be to suitably pair the output with a predicted proxy for the target, using, for example, a proper scoring rule, in the case of probabilistic predictions.
- That `inverse_transform` is an approximate left or right inverse to `transform`.
- That the one-line convenience methods, `transform(learner, ...)` or `predict(learner, ...)`, where implemented, have the same effect as the two-line calls they combine.
- The veracity of `LearnAPI.is_pure_julia(learner)`.
- The second of the two contracts appearing in the `LearnAPI.target_observation_scitype` docstring. The first contract is only tested if `LearnAPI.data_interface(learner)` is `LearnAPI.RandomAccess()` or `LearnAPI.FiniteIterable()`.
Whenever the internal `learner` algorithm involves case distinctions around data or hyperparameters, it is recommended that multiple datasets, and learners with a variety of hyperparameter settings, are explicitly tested.
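For example, one might run the macro over several hyperparameter settings and more than one dataset (a sketch only; `MyRidge`, `dataset1` and `dataset2` are hypothetical):

using LearnTestAPI

@testapi MyRidge(lambda=0.0) dataset1 dataset2 verbosity=0
@testapi MyRidge(lambda=1.0) dataset1 dataset2 verbosity=0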
Role of datasets in tests
Each `dataset` is used as follows.
If `LearnAPI.is_static(learner) == false`, then:

- `dataset` is passed to `fit` and, if necessary, its `update` cousins.
- If `X = LearnAPI.features(learner, dataset) == nothing`, then `predict` and/or `transform` are called with no data. Otherwise, they are called with `X`.
If instead `LearnAPI.is_static(learner) == true`, then `fit` and its cousins are called without any data, and `dataset` is passed directly to `predict` and/or `transform`.
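The following is a simplified sketch of this flow (the actual tests check many additional contracts, and `predict` stands in for whichever of `predict`/`transform` the learner implements):

using LearnAPI

function dataset_flow(learner, dataset)
    if LearnAPI.is_static(learner)
        model = fit(learner)              # static learners get no data in `fit`
        return predict(model, dataset)    # data goes straight to `predict`/`transform`
    else
        model = fit(learner, dataset)
        X = LearnAPI.features(learner, dataset)
        return isnothing(X) ? predict(model) : predict(model, X)
    end
end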
Learners for testing
LearnTestAPI.jl provides some simple, tested, LearnAPI.jl implementations, which may be useful for testing learner wrappers and meta-algorithms.
LearnTestAPI.Ridge — Type

Ridge(; lambda=0.1)
Instantiate a ridge regression learner, with regularization of `lambda`. Data can be provided to `fit` or `predict` in any form supported by the `Saffron` data front end at LearnDataFrontEnds.jl.
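A minimal usage sketch, assuming (as in the `Ensemble` example below) a feature matrix whose columns are observations, paired with a target vector:

using LearnAPI, LearnTestAPI

X = rand(3, 100)   # 100 observations of 3 features
y = rand(100)

learner = LearnTestAPI.Ridge(lambda=0.5)
model = fit(learner, (X, y))
yhat = predict(model, X)   # point predictions on the training features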
LearnTestAPI.BabyRidge — Type

BabyRidge(; lambda=0.1)
Instantiate a ridge regression learner, with regularization of `lambda`.
LearnTestAPI.ConstantClassifier — Type

ConstantClassifier()
Instantiate a constant (dummy) classifier. Can predict `Point` or `Distribution` targets.
LearnTestAPI.TruncatedSVD — Type

TruncatedSVD(; codim=1)
Instantiate a truncated singular value decomposition algorithm for reducing the dimension of observations by `codim`.

Data can be provided to `fit` or `transform` in any form supported by the `Tarragon` data front end at LearnDataFrontEnds.jl. However, the outputs of `transform` and `inverse_transform` are always matrices.
learner = TruncatedSVD()
X = rand(3, 100) # 100 observations in 3-space
model = fit(learner, X)
W = transform(model, X)
X_reconstructed = inverse_transform(model, W)
LearnAPI.extras(model) # returns indim, outdim and singular values
The following fits and transforms in one go:
W = transform(learner, X)
LearnTestAPI.Selector — Type

Selector(; names=Symbol[])
Instantiate a static transformer that selects only the features specified by `names`.
learner = Selector(names=[:x, :w])
X = DataFrames.DataFrame(rand(3, 4), [:x, :y, :z, :w])
model = fit(learner) # no data arguments!
W = transform(model, X)
# one-liner:
@assert transform(learner, X) == W
LearnTestAPI.FancySelector — Type

FancySelector(; names=Symbol[])
Instantiate a feature selector that exposes the names of rejected features. Inputs for `transform` are expected to be tables.
learner = FancySelector(names=[:x, :w])
X = DataFrames.DataFrame(rand(3, 4), [:x, :y, :z, :w])
model = fit(learner) # no data arguments!
transform(model, X) # mutates `model`
@assert rejected(model) == [:y, :z]
LearnTestAPI.NormalEstimator — Type

NormalEstimator()
Instantiate a learner for finding the maximum likelihood normal distribution fitting some real univariate data `y`. Estimates can be updated with new data.
model = fit(NormalEstimator(), y)
d = predict(model) # returns the learned `Normal` distribution
While the above is equivalent to the single operation `d = predict(NormalEstimator(), y)`, the two-step workflow allows for the presentation of additional observations post facto. The following is equivalent to `d2 = predict(NormalEstimator(), vcat(y, ynew))`:
update_observations(model, ynew)
d2 = predict(model)
Inspect all learned parameters with `LearnAPI.extras(model)`. Predict a 95% confidence interval with `predict(model, ConfidenceInterval())`.
LearnTestAPI.Ensemble — Type

Ensemble(atom; rng=Random.default_rng(), n=10)
Instantiate a bagged ensemble of `n` regressors, with base regressor `atom` and random number generator `rng`.
X = rand(3, 100)
y = rand(100)
atom = LearnTestAPI.Ridge()
learner = LearnTestAPI.Ensemble(atom, n=20)
model = fit(learner, (X, y))
# increase ensemble size by 5:
model = update(model, (X, y), :n => 25)
LearnTestAPI.StumpRegressor — Type

StumpRegressor(; ntrees=10, fraction_train=0.8, rng=Random.default_rng())
Instantiate an extremely randomized forest of stump regressors, for training using the LearnAPI.jl interface, as in the example below. By default, 20% of the data is internally set aside to allow for tracking an out-of-sample loss. Internally computed predictions (on the full data) are also exposed to the user.
x = rand(100)
y = sin.(x)
learner = LearnTestAPI.StumpRegressor(ntrees=100)
# train regressor with 100 tree stumps, printing running out-of-sample loss:
model = fit(learner, (x, y), verbosity=2)
# add 400 stumps:
model = update(model, (x, y), ntrees=500)
# predict:
@assert predict(model, x) ≈ LearnAPI.predictions(model)
# inspect other byproducts of training:
LearnAPI.training_losses(model)
LearnAPI.out_of_sample_losses(model)
LearnAPI.trees(model)
Only univariate data is supported. Data is cached, and `update(model, data; ...)` ignores `data`.
Algorithm
Predictions in this extremely simplistic algorithm (not intended for practical application) are averages over an ensemble of decision tree stumps. Each new stump has its feature split threshold chosen uniformly at random, between the minimum and maximum values present in the training data. Predictions on new feature values to the left (resp., right) of the threshold are the mean target values for training observations in which the feature is less than (resp., greater than) the threshold.
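As a rough sketch (not the package's internal code), a single stump and the ensemble average described above might look like this:

using Random, Statistics

# one randomized stump: random threshold, mean target on each side
function random_stump(x, y; rng=Random.default_rng())
    threshold = minimum(x) + rand(rng)*(maximum(x) - minimum(x))
    left = mean(y[x .< threshold])    # mean target left of the threshold
    right = mean(y[x .>= threshold])  # mean target right of the threshold
    xnew -> xnew < threshold ? left : right
end

x = rand(100)
y = sin.(x)
stumps = [random_stump(x, y) for _ in 1:100]

# ensemble prediction is the average over the stumps:
predict_one(xnew) = mean(stump(xnew) for stump in stumps)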
Private methods
For LearnTestAPI.jl developers only, and subject to breaking changes at any time:
LearnTestAPI.@logged_testset — Macro

LearnTestAPI.@logged_testset message verbosity ex
Private method.
Similar to `Test.@testset`, with the exception that you can make `message` longer without loss of clarity, and this message is reported according to the extra parameter `verbosity` provided:
- If `verbosity > 0`, then `message` is always logged to `Logging.Info`.
- If `verbosity ≤ 0`, then `message` is only logged if a `@test` call in `ex` fails on evaluation.
Note that variables defined within the program `ex` are local, and so are not available in subsequent `@logged_testset` calls. However, the last evaluated expression in `ex` is the return value.
julia> f(x) = x^2
julia> LearnTestAPI.@logged_testset "Testing `f`" 1 begin
y = f(2)
@test y > 0
y
end
[ Info: Testing `f`
4
julia> LearnTestAPI.@logged_testset "Testing `f`" 0 begin
y = f(2)
@test y < 0
y
end
Test Failed at REPL[41]:3
Expression: y < 0
Evaluated: 4 < 0
[ Info: Context of failure:
[ Info: Testing `f`
ERROR: There was an error during testing
LearnTestAPI.@nearly — Macro

@nearly lhs == rhs kwargs...
Private method.
Replaces the expression `lhs == rhs` with `isnear(lhs, rhs; kwargs...)`, for testing a weaker form of equality. Here `kwargs...` are keyword arguments accepted in `isapprox(lhs, rhs; kwargs...)`, which is called if `lhs == rhs` fails.

See also `LearnTestAPI.isnear`.
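A hypothetical illustration, based on the description above:

LearnTestAPI.@nearly [1.0, 2.0] == [1.0, 2.0 + 1e-12]
# is equivalent to:
LearnTestAPI.isnear([1.0, 2.0], [1.0, 2.0 + 1e-12])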
LearnTestAPI.isnear — Function

isnear(x, y; kwargs...)
Private method.
Returns `true` if `x == y`. Otherwise, tries to return `isapprox(x, y; kwargs...)`.
julia> isnear('a', 'a')
true
julia> isnear('a', 'b')
[ Info: Tried testing for `≈` because `==` failed.
ERROR: MethodError: no method matching isapprox(::Char, ::Char)
LearnTestAPI.learner_get — Function

learner_get(learner, data, apply=identity)
Private method.
Extract from `LearnAPI.obs(learner, data)`, after applying `apply`, all observations, using the data access API specified by `LearnAPI.data_interface(learner)`.
LearnTestAPI.model_get — Function

model_get(model, data, apply=identity)
Private method.
Extract from `LearnAPI.obs(model, data)`, after applying `apply`, all observations, using the data access API specified by `LearnAPI.data_interface(learner)`, where `learner = LearnAPI.learner(model)`.
LearnTestAPI.verb — Function

LearnTestAPI.verb(ex)
Private method.
If `ex` is a specification of `verbosity`, such as `:(verbosity=1)`, then return the specified value; otherwise, return `nothing`.
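A hypothetical illustration, based on the description above:

LearnTestAPI.verb(:(verbosity=0))   # expected: 0
LearnTestAPI.verb(:(some_learner))  # expected: nothing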
LearnTestAPI.filter_out_verbosity — Function

LearnTestAPI.filter_out_verbosity(exs)
Private method.
Return `(filtered_exs, verbosity)`, where `filtered_exs` is `exs` with any `verbosity` specification dropped, and `verbosity` is the verbosity value (`1` if not specified).
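A hypothetical illustration, based on the description above:

exs = [:learner, :data, :(verbosity=0)]
filtered_exs, verbosity = LearnTestAPI.filter_out_verbosity(exs)
# expected: filtered_exs retains only :learner and :data, and verbosity == 0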