Testing an Implementation

Testing is provided by the LearnTestAPI.jl package documented below.

Quick start

LearnTestAPI (Module)

Module for testing implementations of the interface defined in LearnAPI.jl.

If your package defines an object learner implementing the interface, then put something like this in your test suite:

using LearnTestAPI

# create some test data:
X = ...
y = ...
data = (X, y)

# bump verbosity to debug:
@testapi learner data verbosity=1

Once tests pass, set verbosity=0 to suppress the detailed logging.

For details and options, see LearnTestAPI.@testapi.

Warning

New releases of LearnTestAPI.jl may add tests to @testapi, and this may result in new failures in client package test suites, because of previously undetected broken contracts. Adding a test to @testapi is not considered a breaking change to LearnTestAPI, unless it supports a breaking change to LearnAPI.jl.

The @testapi macro

LearnTestAPI.@testapi (Macro)
@testapi learner dataset1 dataset2 ... verbosity=1

Test that learner correctly implements the LearnAPI.jl interface, by checking contracts against one or more data sets.

using LearnTestAPI

X = (
    feature1 = [1, 2, 3],
    feature2 = ["a", "b", "c"],
    feature3 = [10.0, 20.0, 30.0],
)

@testapi MyFeatureSelector(; features=[:feature3,]) X verbosity=1

Extended help

Assumptions

In some tests, strict == is enforced on the output of predict or transform, unless isapprox is also implemented for that output. If predict outputs categorical vectors, for example, then requiring == in a test is appropriate. On the other hand, if predict outputs some abstract vector of eltype Float32, then isapprox must be implemented for that vector type, because the strict test == is likely to fail. The same applies to more complicated objects, such as probability distributions or sampleable objects: if == is likely to fail in "benign" cases, be sure isapprox is implemented. See LearnTestAPI.isnear for the exact test applied.
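As an illustration of the advice above, suppose a learner's predict returns a custom wrapper around a Float32 vector. The type Predictions below is hypothetical; the sketch shows strict == failing on round-off-sized differences, while a Base.isapprox method that forwards keyword arguments to the array method lets an approximate comparison succeed.

```julia
using LinearAlgebra  # provides isapprox for arrays

# Hypothetical output type of some learner's `predict`:
struct Predictions
    values::Vector{Float32}
end

# Strict equality compares the wrapped vectors exactly:
Base.:(==)(a::Predictions, b::Predictions) = a.values == b.values

# Approximate equality forwards keyword arguments (atol, rtol, ...) to the
# array method, so round-off-sized differences are tolerated:
Base.isapprox(a::Predictions, b::Predictions; kwargs...) =
    isapprox(a.values, b.values; kwargs...)

p1 = Predictions(Float32[0.1, 0.2, 0.3])
p2 = Predictions(p1.values .+ eps(Float32))  # perturbed by round-off-sized noise

p1 == p2  # false
p1 ≈ p2   # true
```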

What is not tested?

When verbosity=1 (the default), the test log describes all contracts tested.

The following are not tested:

  • That the output of LearnAPI.target(learner, data) is indeed a target, in the sense that it can be paired, in some way, with the output of predict. Such a test would need to pair the output suitably with a predicted proxy for the target, using, for example, a proper scoring rule in the case of probabilistic predictions.

  • That inverse_transform is an approximate left or right inverse to transform.

  • That the one-line convenience methods, transform(learner, ...) or predict(learner, ...), where implemented, have the same effect as the two-line calls they combine.

  • The veracity of LearnAPI.is_pure_julia(learner).

  • The second of the two contracts appearing in the LearnAPI.target_observation_scitype docstring. The first contract is only tested if LearnAPI.data_interface(learner) is LearnAPI.RandomAccess() or LearnAPI.FiniteIterable().

Whenever the internal learner algorithm involves case distinctions around data or hyperparameters, it is recommended that multiple datasets, and learners with a variety of hyperparameter settings, are explicitly tested.

Role of datasets in tests

Each dataset is used as follows.

If LearnAPI.is_static(learner) == false, then:

  • dataset is passed to fit and, if necessary, to its update cousins.

  • Let X = LearnAPI.features(learner, dataset). If X == nothing, then predict and/or transform are called with no data. Otherwise, they are called with X.

If instead LearnAPI.is_static(learner) == true, then fit and its cousins are called without any data, and dataset is passed directly to predict and/or transform.
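The routing rules above can be restated as a tiny decision function. Here planned_calls is a hypothetical illustration, not part of LearnTestAPI.jl; it returns descriptive strings rather than making any LearnAPI.jl calls.

```julia
# Hypothetical helper (not part of LearnTestAPI.jl) summarizing the routing rules.
# `static` plays the role of LearnAPI.is_static(learner) and `X` the role of
# LearnAPI.features(learner, dataset); `X` is ignored when `static` is true.
function planned_calls(static::Bool, X)
    if static
        # fit (and any update cousins) receive no data; the dataset goes to the operations:
        return ("fit(learner)", "predict/transform(model, dataset)")
    elseif X === nothing
        # the dataset trains the model, and operations are called without data:
        return ("fit(learner, dataset)", "predict/transform(model)")
    else
        # the dataset trains the model, and operations are called with the features:
        return ("fit(learner, dataset)", "predict/transform(model, X)")
    end
end

planned_calls(false, rand(3, 10))
```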


Learners for testing

LearnTestAPI.jl provides some simple, tested, LearnAPI.jl implementations, which may be useful for testing learner wrappers and meta-algorithms.

LearnTestAPI.Ridge (Type)
Ridge(; lambda=0.1)

Instantiate a ridge regression learner with regularization strength lambda. Data can be provided to fit or predict in any form supported by the Saffron data front end at LearnDataFrontEnds.jl.
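For orientation, ridge regression has the closed-form solution w = (X Xᵀ + λI)⁻¹ X y. The following is a minimal pure-Julia sketch, not the source of LearnTestAPI.Ridge; it assumes a p × n feature matrix X with observations as columns, a length-n target y, and no intercept. The names ridge_coefficients and ridge_predict are hypothetical.

```julia
using LinearAlgebra

# Minimal ridge regression sketch (no intercept).
# X: p × n matrix, one observation per column; y: length-n target.
ridge_coefficients(X, y; lambda=0.1) = (X * X' + lambda * I) \ (X * y)

# Predictions for new features Xnew (also one observation per column):
ridge_predict(w, Xnew) = Xnew' * w

X = rand(3, 100)
w_true = [1.0, -2.0, 0.5]
y = X' * w_true                     # noiseless target

w = ridge_coefficients(X, y; lambda=1e-8)
isapprox(w, w_true; atol=1e-6)      # tiny lambda on noiseless data recovers w_true
yhat = ridge_predict(w, X)
```

Increasing lambda shrinks the coefficient vector towards zero, which is the purpose of the regularization.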

LearnTestAPI.TruncatedSVD (Type)
TruncatedSVD(; codim=1)

Instantiate a truncated singular value decomposition algorithm for reducing the dimension of observations by codim.

Data can be provided to fit or transform in any form supported by the Tarragon data front end at LearnDataFrontEnds.jl. However, the outputs of transform and inverse_transform are always matrices.

using LearnAPI, LearnTestAPI

learner = LearnTestAPI.TruncatedSVD()
X = rand(3, 100)  # 100 observations in 3-space
model = fit(learner, X)
W = transform(model, X)
X_reconstructed = inverse_transform(model, W)
LearnAPI.extras(model) # returns indim, outdim and singular values

The following fits and transforms in one go:

W = transform(learner, X)
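To make "reducing the dimension by codim" concrete, here is a minimal sketch using LinearAlgebra.svd directly. It is not the TruncatedSVD source: it assumes observations are the columns of X, performs no centering, and keeps the leading indim - codim left singular vectors. The helper names are hypothetical.

```julia
using LinearAlgebra

# X: indim × n matrix, one observation per column.
function truncated_svd(X; codim=1)
    outdim = size(X, 1) - codim
    F = svd(X)                       # X = F.U * Diagonal(F.S) * F.Vt
    return F.U[:, 1:outdim], F.S     # leading left singular vectors, all singular values
end

svd_transform(U, X) = U' * X           # outdim × n
svd_inverse_transform(U, W) = U * W    # back to indim × n

X = rand(3, 100)                       # 100 observations in 3-space
U, S = truncated_svd(X; codim=1)
W = svd_transform(U, X)                # 2 × 100
X_reconstructed = svd_inverse_transform(U, W)
```

Because only the smallest singular value is discarded when codim=1, the Frobenius norm of the reconstruction error, norm(X - X_reconstructed), equals S[end] (the Eckart-Young theorem).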
LearnTestAPI.Selector (Type)
Selector(; names=Symbol[])

Instantiate a static transformer that selects only the features specified by names.

using DataFrames, LearnAPI, LearnTestAPI

learner = LearnTestAPI.Selector(names=[:x, :w])
X = DataFrames.DataFrame(rand(3, 4), [:x, :y, :z, :w])
model = fit(learner) # no data arguments!
W = transform(model, X)

# one-liner:
@assert transform(learner, X) == W
LearnTestAPI.FancySelector (Type)
FancySelector(; names=Symbol[])

Instantiate a feature selector that exposes the names of rejected features. Inputs for transform are expected to be tables.

using DataFrames, LearnAPI, LearnTestAPI

learner = LearnTestAPI.FancySelector(names=[:x, :w])
X = DataFrames.DataFrame(rand(3, 4), [:x, :y, :z, :w])
model = fit(learner) # no data arguments!
transform(model, X)  # mutates `model`
@assert rejected(model) == [:y, :z]
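The selection logic behind Selector and FancySelector can be sketched in base Julia, using a NamedTuple of vectors as the table (such a NamedTuple is a valid Tables.jl column table). The helper select_features is hypothetical and not part of LearnTestAPI.jl.

```julia
# Hypothetical helper: keep only the columns in `names`, and report the rest.
function select_features(names::Vector{Symbol}, X::NamedTuple)
    selected = NamedTuple{Tuple(names)}(X)          # keeps only the columns in `names`
    rejected = [k for k in keys(X) if k ∉ names]    # remaining columns, in table order
    return selected, rejected
end

X = (x = rand(3), y = rand(3), z = rand(3), w = rand(3))
W, rejected = select_features([:x, :w], X)
keys(W)    # (:x, :w)
rejected   # [:y, :z]
```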
LearnTestAPI.NormalEstimator (Type)
NormalEstimator()

Instantiate a learner for finding the maximum likelihood normal distribution fitting some real univariate data y. Estimates can be updated with new data.

model = fit(NormalEstimator(), y)
d = predict(model) # returns the learned `Normal` distribution

While the above is equivalent to the single operation d = predict(NormalEstimator(), y), the two-step workflow allows additional observations to be presented after the fact. The following is equivalent to d2 = predict(NormalEstimator(), vcat(y, ynew)):

model = update_observations(model, ynew)
d2 = predict(model)

Inspect all learned parameters with LearnAPI.extras(model). Predict a 95% confidence interval with predict(model, ConfidenceInterval()).
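Incremental updating is possible because the maximum likelihood estimates of a normal distribution depend on the data only through n, Σy and Σy². The sketch below illustrates that idea; the NormalStats type and its helpers are hypothetical and need not match how NormalEstimator stores its state.

```julia
using Statistics

# Hypothetical sufficient statistics for a univariate normal:
struct NormalStats
    n::Int
    s1::Float64   # sum of observations
    s2::Float64   # sum of squared observations
end

NormalStats(y::AbstractVector) = NormalStats(length(y), sum(y), sum(abs2, y))

# Absorb new observations without revisiting old ones:
update_stats(st::NormalStats, ynew) =
    NormalStats(st.n + length(ynew), st.s1 + sum(ynew), st.s2 + sum(abs2, ynew))

# Maximum likelihood estimates (variance uses 1/n, not 1/(n - 1)):
function mle(st::NormalStats)
    μ = st.s1 / st.n
    σ = sqrt(st.s2 / st.n - μ^2)
    return μ, σ
end

y, ynew = randn(50), randn(20)
st = update_stats(NormalStats(y), ynew)
μ, σ = mle(st)   # agrees, up to round-off, with fitting vcat(y, ynew) from scratch
```

This one-pass formula is convenient but can lose precision when the variance is tiny relative to the squared mean; Welford's algorithm is the numerically stabler alternative.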

LearnTestAPI.Ensemble (Type)
Ensemble(atom; rng=Random.default_rng(), n=10)

Instantiate a bagged ensemble of n regressors, with base regressor atom and random number generator rng.

using LearnAPI, LearnTestAPI

X = rand(3, 100)
y = rand(100)

atom = LearnTestAPI.Ridge()
learner = LearnTestAPI.Ensemble(atom, n=20)
model = fit(learner, (X, y))

# increase ensemble size by 5:
model = update(model, (X, y), :n => 25)
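Bagging trains each member on a bootstrap resample of the observations and averages the member predictions. The sketch below assumes that scheme and uses a hypothetical least-squares base learner; none of the names are LearnTestAPI.jl API. In this sketch, growing the ensemble from 20 to 25 members only requires training the 5 new ones.

```julia
using LinearAlgebra, Random

# Hypothetical base learner: least-squares coefficients, with X a p × n matrix
# (one observation per column).
train_atom(X, y) = permutedims(X) \ y
predict_atom(w, Xnew) = permutedims(Xnew) * w

# Train `n` members, each on a bootstrap resample (columns drawn with replacement).
function bag(rng, X, y, n)
    nobs = size(X, 2)
    return map(1:n) do _
        idx = rand(rng, 1:nobs, nobs)
        train_atom(X[:, idx], y[idx])
    end
end

# Ensemble prediction: the average of member predictions.
predict_bag(members, Xnew) = sum(predict_atom(w, Xnew) for w in members) / length(members)

rng = MersenneTwister(123)
X, y = rand(rng, 3, 100), rand(rng, 100)
members = bag(rng, X, y, 20)
members = vcat(members, bag(rng, X, y, 5))  # grow from 20 to 25 members
yhat = predict_bag(members, X)
```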
LearnTestAPI.StumpRegressor (Type)
StumpRegressor(; ntrees=10, fraction_train=0.8, rng=Random.default_rng())

Instantiate an extremely randomized forest of stump regressors, for training using the LearnAPI.jl interface, as in the example below. By default, 20% of the data is internally set aside to allow for tracking an out-of-sample loss. Internally computed predictions (on the full data) are also exposed to the user.

using LearnAPI, LearnTestAPI

x = rand(100)
y = sin.(x)

learner = LearnTestAPI.StumpRegressor(ntrees=100)

# train regressor with 100 tree stumps, printing running out-of-sample loss:
model = fit(learner, (x, y), verbosity=2)

# add 400 stumps:
model = update(model, (x, y), :ntrees => 500)

# predict:
@assert predict(model, x) ≈ LearnAPI.predictions(model)

# inspect other byproducts of training:
LearnAPI.training_losses(model)
LearnAPI.out_of_sample_losses(model)
LearnAPI.trees(model)

Only univariate data is supported. Data is cached and update(model, data; ...) ignores data.

Algorithm

Predictions in this extremely simplistic algorithm (not intended for practical application) are averages over an ensemble of decision tree stumps. Each new stump has its feature split threshold chosen uniformly at random, between the minimum and maximum values present in the training data. Predictions on new feature values to the left (resp., right) of the threshold are the mean target values for training observations in which the feature is less than (resp., greater than) the threshold.
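The algorithm just described is small enough to sketch in full. The following is a simplified, hypothetical rendering (univariate features, no train/holdout split, no loss tracking), not the StumpRegressor source.

```julia
using Random, Statistics

# One stump: a random threshold plus the mean target on each side of it.
struct Stump
    threshold::Float64
    left::Float64    # mean target where x < threshold
    right::Float64   # mean target where x ≥ threshold
end

function random_stump(rng, x, y)
    lo, hi = extrema(x)
    t = lo + (hi - lo) * rand(rng)                   # uniform between min and max
    mask = x .< t
    left = any(mask) ? mean(y[mask]) : mean(y)       # guard against an empty side
    right = any(.!mask) ? mean(y[.!mask]) : mean(y)
    return Stump(t, left, right)
end

predict_stump(s::Stump, xnew) = xnew < s.threshold ? s.left : s.right

# Forest prediction: the average over all stumps.
predict_forest(stumps, xnew) = mean(predict_stump(s, xnew) for s in stumps)

rng = MersenneTwister(42)
x = rand(rng, 100)
y = sin.(x)
stumps = [random_stump(rng, x, y) for _ in 1:100]
yhat = predict_forest.(Ref(stumps), x)
```

Since sin is increasing on [0, 1], every stump here predicts a larger value to the right of its threshold, so the averaged predictions are non-decreasing in x.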


Private methods

For LearnTestAPI.jl developers only, and subject to breaking changes at any time:

LearnTestAPI.@logged_testset (Macro)
LearnTestAPI.@logged_testset message verbosity ex

Private method.

Similar to Test.@testset with the exception that you can make message longer without loss of clarity, and this message is reported according to the extra parameter verbosity provided:

  • If verbosity > 0, then message is always logged at the Logging.Info level.

  • If verbosity ≤ 0, then message is logged only if a @test call in ex fails on evaluation.

Note that variables defined within the program ex are local, and so are not available in subsequent @logged_testset calls. However, the last evaluated expression in ex is the return value.

julia> f(x) = x^2
julia> LearnTestAPI.@logged_testset "Testing `f`" 1 begin
           y = f(2)
           @test y > 0
           y
       end
[ Info: Testing `f`
4

julia> LearnTestAPI.@logged_testset "Testing `f`" 0 begin
           y = f(2)
           @test y < 0
           y
       end
Test Failed at REPL[41]:3
  Expression: y < 0
   Evaluated: 4 < 0

[ Info: Context of failure:
[ Info: Testing `f`
ERROR: There was an error during testing
LearnTestAPI.@nearly (Macro)
@nearly lhs == rhs kwargs...

Private method.

Replaces the expression lhs == rhs with isnear(lhs, rhs; kwargs...) for testing a weaker form of equality. Here kwargs... are keyword arguments accepted in isapprox(lhs, rhs; kwargs...), which is called if lhs == rhs fails.

See also LearnTestAPI.isnear.
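The rewriting performed by @nearly can be sketched with a small macro. Both isnear_sketch and @nearly_sketch below are hypothetical stand-ins, written only to show the technique (inspect the lhs == rhs expression, convert trailing key=value arguments to keyword arguments, and emit a call); they are not the LearnTestAPI.jl source.

```julia
# Hypothetical stand-in for `isnear`: try `==` first, then fall back to `isapprox`.
function isnear_sketch(x, y; kwargs...)
    x == y && return true
    return isapprox(x, y; kwargs...)
end

# Hypothetical stand-in for `@nearly`: rewrite `lhs == rhs` as a call to `isnear_sketch`.
macro nearly_sketch(ex, kwargs...)
    (Meta.isexpr(ex, :call) && ex.args[1] == :(==)) ||
        throw(ArgumentError("expected an expression of the form `lhs == rhs`"))
    lhs, rhs = ex.args[2], ex.args[3]
    # A macro receives `atol=0.1` as Expr(:(=), :atol, 0.1); a call needs Expr(:kw, ...):
    kws = [Expr(:kw, kw.args[1], kw.args[2]) for kw in kwargs]
    return esc(:($isnear_sketch($lhs, $rhs; $(kws...))))
end

@nearly_sketch 0.1 + 0.2 == 0.3       # true, via the isapprox fallback
@nearly_sketch 1.0 == 1.05 atol=0.1   # true, keyword passed through to isapprox
```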

LearnTestAPI.isnear (Function)
isnear(x, y; kwargs...)

Private method.

Returns true if x == y. Otherwise try to return isapprox(x, y; kwargs...).

julia> LearnTestAPI.isnear('a', 'a')
true

julia> LearnTestAPI.isnear('a', 'b')
[ Info: Tried testing for `≈` because `==` failed.
ERROR: MethodError: no method matching isapprox(::Char, ::Char)
LearnTestAPI.learner_get (Function)
learner_get(learner, data, apply=identity)

Private method.

Extract from LearnAPI.obs(learner, data), after applying apply, all observations, using the data access API specified by LearnAPI.data_interface(learner).

LearnTestAPI.model_get (Function)
model_get(model, data)

Private method.

Extract from LearnAPI.obs(model, data), after applying apply, all observations, using the data access API specified by LearnAPI.data_interface(learner), where learner = LearnAPI.learner(model).

LearnTestAPI.verb (Function)
LearnTestAPI.verb(ex)

Private method.

If ex is a specification of verbosity, such as :(verbosity=1), then return the specified value; otherwise, return nothing.

LearnTestAPI.filter_out_verbosity (Function)
LearnTestAPI.filter_out_verbosity(exs)

Private method.

Return (filtered_exs, verbosity) where filtered_exs is exs with any verbosity specification dropped, and verbosity is the verbosity value (1 if not specified).
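The behaviour of verb and filter_out_verbosity can be illustrated with two small functions. Note that a macro receives an argument such as verbosity=1 as the expression Expr(:(=), :verbosity, 1). The names verb_sketch and filter_out_verbosity_sketch are hypothetical re-implementations for illustration only.

```julia
# Return the value in a `verbosity=...` expression, or `nothing` otherwise.
function verb_sketch(ex)
    if Meta.isexpr(ex, :(=)) && ex.args[1] == :verbosity
        return ex.args[2]
    end
    return nothing
end

# Drop any `verbosity=...` expression and return its value separately (default 1).
function filter_out_verbosity_sketch(exs)
    verbosity = 1
    filtered = filter(exs) do ex
        v = verb_sketch(ex)
        v === nothing && return true   # keep ordinary arguments
        verbosity = v                  # record the verbosity and drop the expression
        return false
    end
    return filtered, verbosity
end

exs = [:learner, :data, :(verbosity = 0)]
filtered, v = filter_out_verbosity_sketch(exs)   # filtered == [:learner, :data], v == 0
```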
