predict, transform and inverse_transform

predict(model, kind_of_proxy, data)
transform(model, data)
inverse_transform(model, data)

Versions without the data argument may apply, for example in density estimation.

Typical workflows

Train some supervised learner:

model = fit(learner, (X, y))

Predict probability distributions:

ŷ = predict(model, Distribution(), Xnew)

Generate point predictions:

ŷ = predict(model, Point(), Xnew)

Train a dimension-reducing learner:

model = fit(learner, X)
Xnew_reduced = transform(model, Xnew)

Apply an approximate right inverse:

inverse_transform(model, Xnew_reduced)

Fit and transform in one line:

transform(learner, data) # `fit` implied

An advanced workflow

fitobs = obs(learner, (X, y)) # learner-specific repr. of data
model = fit(learner, MLCore.getobs(fitobs, 1:100))
predictobs = obs(model, MLCore.getobs(X, 101:150))
ŷ = predict(model, Point(), predictobs)

Implementation guide


Predict or transform?

If the learner has a notion of target variable, then use predict to output each supported kind of target proxy (Point(), Distribution(), etc).

For output not associated with a target variable, implement transform instead, which does not dispatch on LearnAPI.KindOfProxy, but can be optionally paired with an implementation of inverse_transform, for returning (approximate) right or left inverses to transform.

Of course, a single learner can implement both predict and transform. For example, a K-means clustering algorithm can predict labels, and transform to reduce dimension using distances from the cluster centres.
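To make the "both methods" case concrete, here is a minimal sketch of a learner that implements predict and transform together. The types NearestCentre and NearestCentreFitted are hypothetical, and the centres are supplied by the user rather than trained, so this is a stand-in for K-means and not an implementation of it. Data is assumed to be a matrix with observations as columns.

```julia
using LearnAPI
using LinearAlgebra

# Hypothetical learner: uses user-supplied centres (no real K-means training).
struct NearestCentre
    centres::Matrix{Float64}  # one centre per column
end

struct NearestCentreFitted
    learner::NearestCentre
end

LearnAPI.learner(model::NearestCentreFitted) = model.learner

# "Training" only wraps the learner; a real K-means would optimise the centres here.
LearnAPI.fit(learner::NearestCentre, X; verbosity=1) = NearestCentreFitted(learner)

# Distances from every observation (column of X) to every centre: a k × n matrix.
function LearnAPI.transform(model::NearestCentreFitted, X)
    C = model.learner.centres
    return [norm(x - c) for c in eachcol(C), x in eachcol(X)]
end

# Point prediction: index of the nearest centre for each observation.
LearnAPI.predict(model::NearestCentreFitted, ::Point, X) =
    [argmin(col) for col in eachcol(LearnAPI.transform(model, X))]

learner = NearestCentre([0.0 10.0; 0.0 10.0])  # centres (0, 0) and (10, 10)
X = [1.0 9.0 0.0; 1.0 8.0 -1.0]                # three observations
model = fit(learner, X)
labels = predict(model, Point(), X)            # [1, 2, 1]
D = transform(model, X)                        # 2 × 3 matrix of distances
```

Here predict answers "which cluster?", the target-like quantity, while transform returns a new representation of the data with one row per centre.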

One-liners combining fit and transform/predict

Learners may additionally overload transform to apply fit first, using the supplied data if required, and then immediately transform the same data. In that case the first argument of transform is a learner instead of the output of fit:

transform(learner, data) # `fit` implied

This will be shorthand for

model = fit(learner, data) # or `fit(learner)` in the static case
transform(model, data)

The same remarks apply to predict, as in

predict(learner, kind_of_proxy, data) # `fit` implied

LearnAPI.jl does not, however, guarantee the provision of these one-liners.
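As an illustration of how an implementation might provide such a one-liner, the sketch below overloads transform on the learner type itself. Centerer and CentererFitted are hypothetical names introduced for this example, and data is assumed to be a matrix with observations as columns.

```julia
using LearnAPI
using Statistics

# Hypothetical learner: subtracts the training mean from each feature (row).
struct Centerer end

struct CentererFitted
    learner::Centerer
    means::Vector{Float64}
end

LearnAPI.learner(model::CentererFitted) = model.learner

# Observations are columns, so per-feature means are taken along dimension 2.
LearnAPI.fit(learner::Centerer, X; verbosity=1) =
    CentererFitted(learner, vec(mean(X, dims=2)))

LearnAPI.transform(model::CentererFitted, X) = X .- model.means

# Optional one-liner: fit on `X`, then transform the same `X`.
LearnAPI.transform(learner::Centerer, X) =
    LearnAPI.transform(LearnAPI.fit(learner, X; verbosity=0), X)

X = [1.0 2.0 3.0; 10.0 20.0 30.0]
W = transform(Centerer(), X)                 # `fit` implied
W_two_step = transform(fit(Centerer(), X), X)
```

Dispatch on the first argument (learner versus fitted model) is what distinguishes the one-liner from the ordinary method, so the two can coexist without ambiguity.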


predict(model, kind_of_proxy::LearnAPI.KindOfProxy, data)
predict(model, data)

The first signature returns target predictions, or proxies for target predictions, for input features data, according to some model returned by fit. Where supported, these are literally target predictions if kind_of_proxy = Point(), and probability density/mass functions if kind_of_proxy = Distribution(). List all options with LearnAPI.kinds_of_proxy(learner), where learner = LearnAPI.learner(model).

model = fit(learner, (X, y))
predict(model, Point(), Xnew)

The shortcut predict(model, data) calls the first method with a learner-specific kind_of_proxy, namely the first element of LearnAPI.kinds_of_proxy(learner), which lists all supported target proxies.

The argument model is anything returned by a call of the form fit(learner, ...).

If LearnAPI.features(learner, data) always returns nothing, where learner = LearnAPI.learner(model), then the argument data is omitted in both signatures. Density estimators are an example.
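The data-free signature can be sketched as follows. MeanEstimator is a hypothetical learner that summarises a target vector by its mean; it stands in for a genuine density estimator, which would more naturally return a Distribution() proxy. The two-argument form LearnAPI.features(learner, data) is assumed here.

```julia
using LearnAPI
using Statistics

# Hypothetical learner with a target but no features.
struct MeanEstimator end

struct MeanEstimatorFitted
    learner::MeanEstimator
    mean::Float64
end

LearnAPI.learner(model::MeanEstimatorFitted) = model.learner

# Training data is just the target vector `y`.
LearnAPI.fit(learner::MeanEstimator, y; verbosity=1) =
    MeanEstimatorFitted(learner, mean(y))

# No features, so `features` returns `nothing` and `predict` takes no `data`.
LearnAPI.features(::MeanEstimator, data) = nothing
LearnAPI.predict(model::MeanEstimatorFitted, ::Point) = model.mean

model = fit(MeanEstimator(), [1.0, 2.0, 6.0])
ŷ = predict(model, Point())   # 3.0
```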

See also fit, transform, inverse_transform.

Extended help

In the special case LearnAPI.is_static(learner) == true, it is possible that predict(model, ...) will mutate model, but not in a way that affects subsequent predict calls.

New implementations

If there is no notion of a "target" variable in the LearnAPI.jl sense, or you need an operation with an inverse, implement transform instead.

Implementation is optional. Implement only the first signature (with or without the data argument); each kind_of_proxy::KindOfProxy that gets an implementation must be added to the list returned by LearnAPI.kinds_of_proxy(learner). List all available kinds of proxy with LearnAPI.kinds_of_proxy().

When predict is implemented, it may be necessary to overload LearnAPI.features. If data is not present in the implemented signature (e.g., for density estimators) then LearnAPI.features(learner, data) must always return nothing.

If implemented, you must include :(LearnAPI.predict) in the tuple returned by the LearnAPI.functions trait.
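The requirements above can be sketched for a small supervised learner. LeastSquares and LeastSquaresFitted are hypothetical types, the trait list is limited to what this sketch actually implements, and the features are assumed to be a matrix with observations as columns.

```julia
using LearnAPI
using LinearAlgebra

# Hypothetical learner: least-squares regression without an intercept.
struct LeastSquares end

struct LeastSquaresFitted
    learner::LeastSquares
    coefficients::Vector{Float64}
end

LearnAPI.learner(model::LeastSquaresFitted) = model.learner

# `X` has observations as columns; `y` is the target vector.
function LearnAPI.fit(learner::LeastSquares, data; verbosity=1)
    X, y = data
    return LeastSquaresFitted(learner, X' \ y)
end

# Only the signature with an explicit `kind_of_proxy` is implemented.
LearnAPI.predict(model::LeastSquaresFitted, ::Point, Xnew) =
    Xnew' * model.coefficients

# Declare the supported proxies and the implemented functions.
@trait(
    LeastSquares,
    constructor = LeastSquares,
    kinds_of_proxy = (Point(),),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.predict),
    )
)

X = [1.0 2.0 3.0]                 # one feature, three observations
y = [2.0, 4.0, 6.0]
model = fit(LeastSquares(), (X, y))
ŷ = predict(model, Point(), [4.0 5.0])   # approximately [8.0, 10.0]
ŷ_shortcut = predict(model, [4.0 5.0])   # uses the first declared kind of proxy
```

Because Point() is the first (and only) entry of kinds_of_proxy, the shortcut predict(model, data) resolves to the Point() method without any further code in the implementation.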

If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:

predict(LearnAPI.strip(model), args...) == predict(model, args...)

If LearnAPI.is_static(learner) is true, then predict may mutate its first argument (to record byproducts of the computation not naturally part of the return value) but not in a way that alters the result of a subsequent call to predict, transform or inverse_transform. See more at fit.

Assumptions about data

By default, it is assumed that data supports the LearnAPI.RandomAccess interface; this includes all matrices, with observations-as-columns, most tables, and tuples thereof. See LearnAPI.RandomAccess for details. If this is not the case then an implementation must either: (i) overload obs to articulate how provided data can be transformed into a form that does support LearnAPI.RandomAccess; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API. Refer to the document strings for details.

transform(model, data)

Return a transformation of some data, using some model, as returned by fit.


Below, X and Xnew are data of the same form.

For a learner that generalizes to new data ("learns"):

model = fit(learner, X; verbosity=0)
transform(model, Xnew)

or, in one step (where supported):

W = transform(learner, X) # `fit` implied

For a static (non-generalizing) transformer:

model = fit(learner)
W = transform(model, X)

or, in one step (where supported):

W = transform(learner, X) # `fit` implied

In the special case LearnAPI.is_static(learner) == true, it is possible that transform(model, ...) will mutate model, but not in a way that affects subsequent transform calls.

See also fit, predict, inverse_transform.

Extended help

New implementations

Implementation for new LearnAPI.jl learners is optional.

When transform is implemented, it may be necessary to overload LearnAPI.features.

If implemented, you must include :(LearnAPI.transform) in the tuple returned by the LearnAPI.functions trait.

An implementation is free to implement transform signatures with additional positional arguments (e.g., data-slurping signatures) but LearnAPI.jl is silent about their interpretation or existence.

If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:

transform(LearnAPI.strip(model), args...) == transform(model, args...)

If LearnAPI.is_static(learner) is true, then transform may mutate its first argument (to record byproducts of the computation not naturally part of the return value) but not in a way that alters the result of a subsequent call to predict, transform or inverse_transform. See more at fit.
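A minimal sketch of a static transformer that records a byproduct may help here. Clipper and ClipperFitted are hypothetical types; the field nclipped is an illustrative byproduct that never influences the values returned by transform.

```julia
using LearnAPI

# Hypothetical static transformer: clips values to [lo, hi]; nothing is learned.
struct Clipper
    lo::Float64
    hi::Float64
end

# The "model" wraps the learner plus a mutable slot for a byproduct.
mutable struct ClipperFitted
    learner::Clipper
    nclipped::Int   # number of entries clipped in the most recent call
end

LearnAPI.learner(model::ClipperFitted) = model.learner
LearnAPI.is_static(::Clipper) = true

# Static case: `fit` consumes no data.
LearnAPI.fit(learner::Clipper; verbosity=1) = ClipperFitted(learner, 0)

function LearnAPI.transform(model::ClipperFitted, X)
    lo, hi = model.learner.lo, model.learner.hi
    model.nclipped = count(x -> x < lo || x > hi, X)  # byproduct only
    return clamp.(X, lo, hi)
end

model = fit(Clipper(0.0, 1.0))
W = transform(model, [-0.5, 0.2, 1.7])   # [0.0, 0.2, 1.0]
n = model.nclipped                       # 2
W_again = transform(model, [-0.5, 0.2, 1.7])
```

The mutation changes only the recorded count, so repeating the call returns the same output, which is the condition stated above.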

Assumptions about data

By default, it is assumed that data supports the LearnAPI.RandomAccess interface; this includes all matrices, with observations-as-columns, most tables, and tuples thereof. See LearnAPI.RandomAccess for details. If this is not the case then an implementation must either: (i) overload obs to articulate how provided data can be transformed into a form that does support LearnAPI.RandomAccess; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API. Refer to the document strings for details.

inverse_transform(model, data)

Inverse transform data according to some model returned by fit. Here "inverse" is to be understood broadly, e.g., as an approximate right or left inverse for transform.


In the following, learner is some dimension-reducing algorithm that generalizes to new data (such as PCA); Xtrain is the training input and Xnew the input to be reduced:

model = fit(learner, Xtrain)
W = transform(model, Xnew)       # reduced version of `Xnew`
Ŵ = inverse_transform(model, W)  # embedding of `W` in original space

See also fit, transform, predict.

Extended help

New implementations

Implementation is optional. If implemented, you must include :(LearnAPI.inverse_transform) in the tuple returned by the LearnAPI.functions trait.

If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:

inverse_transform(LearnAPI.strip(model), args...) == inverse_transform(model, args...)
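The following sketch puts the pieces together for a PCA-like learner: transform, inverse_transform, the functions trait, and the strip identity. ToyPCA and ToyPCAFitted are hypothetical types, the projection is computed with a plain SVD rather than any particular library's PCA, and data is assumed to be a matrix with observations as columns.

```julia
using LearnAPI
using LinearAlgebra, Statistics

# Hypothetical learner: project onto the leading `k` principal directions.
struct ToyPCA
    k::Int
end
ToyPCA(; k=1) = ToyPCA(k)

struct ToyPCAFitted{R}
    learner::ToyPCA
    mean::Vector{Float64}
    P::Matrix{Float64}   # p × k matrix with orthonormal columns
    report::R            # training byproduct, dropped by `LearnAPI.strip`
end

LearnAPI.learner(model::ToyPCAFitted) = model.learner

function LearnAPI.fit(learner::ToyPCA, X; verbosity=1)
    μ = vec(mean(X, dims=2))   # per-feature means (observations are columns)
    F = svd(X .- μ)
    return ToyPCAFitted(learner, μ, F.U[:, 1:learner.k], (singular_values=F.S,))
end

LearnAPI.transform(model::ToyPCAFitted, X) = model.P' * (X .- model.mean)
LearnAPI.inverse_transform(model::ToyPCAFitted, W) = model.P * W .+ model.mean
LearnAPI.strip(model::ToyPCAFitted) =
    ToyPCAFitted(model.learner, model.mean, model.P, nothing)

@trait(
    ToyPCA,
    constructor = ToyPCA,
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.strip),
        :(LearnAPI.transform),
        :(LearnAPI.inverse_transform),
    )
)

Xtrain = [1.0 2.0 3.0 4.0; 2.0 4.1 5.9 8.0; 0.5 0.4 0.6 0.5]
model = fit(ToyPCA(k=2), Xtrain)
W = transform(model, Xtrain)        # 2 × 4 reduced data
X̂ = inverse_transform(model, W)     # 3 × 4, embedded in the original space
```

In this sketch inverse_transform is a right inverse of transform: applying transform to X̂ recovers W up to rounding, because the columns of P are orthonormal. It is not a left inverse, since X̂ generally differs from Xtrain when k is smaller than the number of features.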