predict, transform, and inverse_transform
predict(model, kind_of_proxy, data)
transform(model, data)
inverse_transform(model, data)
Versions without the data argument may apply, for example in density estimation.
Typical workflows
Train some supervised learner:
model = fit(learner, (X, y))
Predict probability distributions:
ŷ = predict(model, Distribution(), Xnew)
Generate point predictions:
ŷ = predict(model, Point(), Xnew)
Train a dimension-reducing learner:
model = fit(learner, X)
Xnew_reduced = transform(model, Xnew)
Apply an approximate right inverse:
inverse_transform(model, Xnew_reduced)
Fit and transform in one line:
transform(learner, data) # `fit` implied
An advanced workflow
fitobs = obs(learner, (X, y)) # learner-specific repr. of data
model = fit(learner, MLUtils.getobs(fitobs, 1:100))
predictobs = obs(model, MLUtils.getobs(X, 101:150))
ŷ = predict(model, Point(), predictobs)
Implementation guide
| method | compulsory? | fallback |
|---|---|---|
| predict | no | none |
| transform | no | none |
| inverse_transform | no | none |
Predict or transform?
If the learner has a notion of target variable, then use predict to output each supported kind of target proxy (Point(), Distribution(), etc.).
For output not associated with a target variable, implement transform instead, which does not dispatch on LearnAPI.KindOfProxy, but can be optionally paired with an implementation of inverse_transform, for returning (approximate) right or left inverses to transform.
Of course, a single learner can implement both a predict and a transform method. For example, a K-means clustering algorithm can predict labels and transform to reduce dimension using distances from the cluster centres.
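A usage sketch of the latter, assuming a hypothetical learner KMeans that supports both operations:
learner = KMeans(k=3)                   # hypothetical K-means learner
model = fit(learner, X)
labels = predict(model, Point(), Xnew)  # cluster labels
W = transform(model, Xnew)              # distances to the three cluster centres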
One-liners combining fit and transform/predict
Learners may additionally overload transform to apply fit first, using the supplied data if required, and then immediately transform the same data. In that case the first argument of transform is a learner instead of the output of fit:
transform(learner, data) # `fit` implied
This will be shorthand for
model = fit(learner, X) # or `fit(learner)` in the static case
transform(model, X)
The same remarks apply to predict, as in
predict(learner, kind_of_proxy, data) # `fit` implied
LearnAPI.jl does not, however, guarantee the provision of these one-liners.
Reference
LearnAPI.predict — Function
predict(model, kind_of_proxy::LearnAPI.KindOfProxy, data)
predict(model, data)
The first signature returns target predictions, or proxies for target predictions, for input features data, according to some model returned by fit. Where supported, these are literally target predictions if kind_of_proxy = Point(), and probability density/mass functions if kind_of_proxy = Distribution(). List all options with LearnAPI.kinds_of_proxy(learner), where learner = LearnAPI.learner(model).
model = fit(learner, (X, y))
predict(model, Point(), Xnew)
The shortcut predict(model, data) calls the first method with learner-specific kind_of_proxy, namely the first element of LearnAPI.kinds_of_proxy(learner), which lists all supported target proxies.
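Assuming predict is deterministic for the learner in question, the shortcut therefore agrees with spelling out the preferred proxy explicitly. For example:
kind = first(LearnAPI.kinds_of_proxy(learner))       # the learner's preferred target proxy
predict(model, Xnew) == predict(model, kind, Xnew)   # true, as both dispatch to the same method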
The argument model is anything returned by a call of the form fit(learner, ...).
If LearnAPI.features(LearnAPI.learner(model)) == nothing, then the argument data is omitted in both signatures. An example is density estimators.
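A sketch of this case, for a hypothetical learner type MyDensityEstimator:
learner = MyDensityEstimator()  # hypothetical
model = fit(learner, y)         # `y` is the data whose distribution is modelled
d̂ = predict(model)              # no `data` argument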
See also fit, transform, inverse_transform.
Extended help
In the special case LearnAPI.is_static(learner) == true, it is possible that predict(model, ...) will mutate model, but not in a way that affects subsequent predict calls.
New implementations
If there is no notion of a "target" variable in the LearnAPI.jl sense, or you need an operation with an inverse, implement transform instead.
Implementation is optional. Only the first signature (with or without the data argument) is implemented, but each kind_of_proxy::KindOfProxy that gets an implementation must be added to the list returned by LearnAPI.kinds_of_proxy(learner). List all available kinds of proxy by doing LearnAPI.kinds_of_proxy().
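By way of illustration only, here is a minimal sketch of this pattern for a hypothetical learner type MyClassifier whose fit output has type MyClassifierFitted; fit itself, the LearnAPI.functions declaration, and the other required methods are omitted:
import LearnAPI

struct MyClassifier end
struct MyClassifierFitted
    learner::MyClassifier
    parameters            # hypothetical fitted parameters
end

# supported kinds of proxy, preferred kind first:
LearnAPI.kinds_of_proxy(::MyClassifier) = (LearnAPI.Point(), LearnAPI.Distribution())

# one `predict` method per supported kind of proxy:
LearnAPI.predict(model::MyClassifierFitted, ::LearnAPI.Point, data) =
    _point_predictions(model, data)          # `_point_predictions` is a hypothetical helper
LearnAPI.predict(model::MyClassifierFitted, ::LearnAPI.Distribution, data) =
    _probabilistic_predictions(model, data)  # also hypothetical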
If data is not present in the implemented signature (e.g., for density estimators) then LearnAPI.features(learner, data) must return nothing.
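For example, for the hypothetical density estimator sketched earlier:
LearnAPI.features(::MyDensityEstimator, data) = nothing  # `MyDensityEstimator` is hypothetical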
If implemented, you must include :(LearnAPI.predict) in the tuple returned by the LearnAPI.functions trait.
If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:
predict(LearnAPI.strip(model), args...) == predict(model, args...)
If LearnAPI.is_static(learner) is true, then predict may mutate its first argument (to record byproducts of the computation not naturally part of the return value), but not in a way that alters the result of a subsequent call to predict, transform or inverse_transform. See more at fit.
Assumptions about data
By default, it is assumed that data supports the LearnAPI.RandomAccess interface; this includes all matrices, with observations-as-columns, most tables, and tuples thereof. See LearnAPI.RandomAccess for details. If this is not the case then an implementation must either: (i) overload obs to articulate how provided data can be transformed into a form that does support LearnAPI.RandomAccess; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API. Refer to the document strings for details.
LearnAPI.transform — Function
transform(model, data)
Return a transformation of some data, using some model, as returned by fit.
Example
Below, X and Xnew are data of the same form.
For a learner that generalizes to new data ("learns"):
model = fit(learner, X; verbosity=0)
transform(model, Xnew)
or, in one step (where supported):
W = transform(learner, X) # `fit` implied
For a static (non-generalizing) transformer:
model = fit(learner)
W = transform(model, X)
or, in one step (where supported):
W = transform(learner, X) # `fit` implied
In the special case LearnAPI.is_static(learner) == true, it is possible that transform(model, ...) will mutate model, but not in a way that affects subsequent transform calls.
See also fit, predict, inverse_transform.
Extended help
New implementations
Implementation for new LearnAPI.jl learners is optional. If implemented, you must include :(LearnAPI.transform) in the tuple returned by the LearnAPI.functions trait.
An implementation is free to implement transform signatures with additional positional arguments (e.g., data-slurping signatures), but LearnAPI.jl is silent about their interpretation or existence.
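As an illustration only, here is a sketch of a hypothetical centering transformer; fit, the LearnAPI.functions declaration, and the other required methods are omitted:
import LearnAPI

struct Centerer end
struct CentererFitted
    learner::Centerer
    means::Vector{Float64}   # per-feature means computed by `fit`
end

# `X` is assumed to be a matrix with observations as columns:
LearnAPI.transform(model::CentererFitted, X) = X .- model.means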
If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:
transform(LearnAPI.strip(model), args...) == transform(model, args...)
If LearnAPI.is_static(learner) is true, then transform may mutate its first argument (to record byproducts of the computation not naturally part of the return value), but not in a way that alters the result of a subsequent call to predict, transform or inverse_transform. See more at fit.
Assumptions about data
By default, it is assumed that data supports the LearnAPI.RandomAccess interface; this includes all matrices, with observations-as-columns, most tables, and tuples thereof. See LearnAPI.RandomAccess for details. If this is not the case then an implementation must either: (i) overload obs to articulate how provided data can be transformed into a form that does support LearnAPI.RandomAccess; or (ii) overload the trait LearnAPI.data_interface to specify a more relaxed data API. Refer to the document strings for details.
LearnAPI.inverse_transform — Function
inverse_transform(model, data)
Inverse transform data according to some model returned by fit. Here "inverse" is to be understood broadly, e.g., as an approximate right or left inverse for transform.
Example
In the following, learner is some dimension-reducing algorithm that generalizes to new data (such as PCA); Xtrain is the training input and Xnew the input to be reduced:
model = fit(learner, Xtrain)
W = transform(model, Xnew) # reduced version of `Xnew`
Ŵ = inverse_transform(model, W) # embedding of `W` in original space
See also fit, transform, predict.
Extended help
New implementations
Implementation is optional. If implemented, you must include :(LearnAPI.inverse_transform) in the tuple returned by the LearnAPI.functions trait.
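Continuing the hypothetical centering sketch from the transform docstring above, where the inverse is exact:
LearnAPI.inverse_transform(model::CentererFitted, W) = W .+ model.means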
If, additionally, LearnAPI.strip(model) is overloaded, then the following identity must hold:
inverse_transform(LearnAPI.strip(model), args...) == inverse_transform(model, args...)