predict, transform, and relatives
Standard methods:
predict(model, kind_of_proxy, data...) -> prediction
transform(model, data...) -> transformed_data
inverse_transform(model, data...) -> inverted_data
Methods consuming the output, obsdata, of the data-preprocessor obs:
obspredict(model, kind_of_proxy, obsdata) -> prediction
obstransform(model, obsdata) -> transformed_data
Typical workflows
# Train some supervised `algorithm`:
model = fit(algorithm, X, y)
# Predict probability distributions:
ŷ = predict(model, Distribution(), Xnew)
# Generate point predictions:
ŷ = predict(model, LiteralTarget(), Xnew)
# Train a dimension-reducing `algorithm`:
model = fit(algorithm, X)
Xnew_reduced = transform(model, Xnew)
# Apply an approximate right inverse:
inverse_transform(model, Xnew_reduced)
An advanced workflow
fitdata = obs(fit, algorithm, X, y)
predictdata = obs(predict, algorithm, Xnew)
model = obsfit(algorithm, fitdata)
ŷ = obspredict(model, LiteralTarget(), predictdata)
Implementation guide
The methods predict and transform are not directly overloaded. Implement obspredict and obstransform instead:
method | compulsory? | fallback | requires |
---|---|---|---|
obspredict | no | none | fit |
obstransform | no | none | fit |
inverse_transform | no | none | fit, obstransform |
Predict or transform?
If the algorithm has a notion of target variable, then arrange for obspredict to output each supported kind of target proxy (LiteralTarget(), Distribution(), etc.).
For output not associated with a target variable, implement obstransform instead. This method does not dispatch on LearnAPI.KindOfProxy, but it can optionally be paired with an implementation of inverse_transform for returning (approximate) right inverses to transform.
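The following minimal, self-contained sketch shows this division of labour. It does not load LearnAPI.jl. Truncator, TruncatorModel, and the local fit, obs, transform, obstransform and inverse_transform are hypothetical stand-ins written for illustration. The toy transformer keeps the first k columns, and its inverse pads the result with zero columns. That makes it a right inverse but not a left inverse.

```julia
module TransformSketch

# Hypothetical toy transformer (not part of LearnAPI.jl): keeps the first k columns.
struct Truncator
    k::Int
end

struct TruncatorModel
    algorithm::Truncator
    p::Int  # number of columns seen during training
end

# Plays the role of LearnAPI.algorithm(model).
algorithm(model::TruncatorModel) = model.algorithm

# "Learning" here just records the number of input columns.
fit(alg::Truncator, X) = TruncatorModel(alg, size(X, 2))

# Without an obs overload, obsdata is just the tuple of data arguments.
obs(f, alg, data...) = data

# Defined once and never overloaded per algorithm; it delegates to obstransform.
transform(model, data...) = obstransform(model, obs(transform, algorithm(model), data...))

# No kind-of-proxy argument: this output is not a target proxy.
function obstransform(model::TruncatorModel, obsdata)
    X = only(obsdata)  # obsdata is the 1-tuple (X,)
    return X[:, 1:model.algorithm.k]
end

# Right inverse: pad with zero columns, so transform(model, inverse_transform(model, W)) == W.
function inverse_transform(model::TruncatorModel, W)
    n = size(W, 1)
    return hcat(W, zeros(eltype(W), n, model.p - size(W, 2)))
end

end # module
```

In this toy, transform(model, inverse_transform(model, W)) recovers W exactly. By contrast, inverse_transform(model, transform(model, X)) only approximates X, because the dropped columns come back as zeros.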
Reference
LearnAPI.predict (Function)
predict(model, kind_of_proxy::LearnAPI.KindOfProxy, data...)
predict(model, data...)
The first signature returns target or target proxy predictions for input features data, according to some model returned by fit or obsfit. Where supported, these are literal target predictions if kind_of_proxy = LiteralTarget(), and probability density/mass functions if kind_of_proxy = Distribution(). List all options with LearnAPI.kinds_of_proxy(algorithm), where algorithm = LearnAPI.algorithm(model).
The shortcut predict(model, data...) = predict(model, LiteralTarget(), data...) is also provided.
Arguments
model: anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.
data: tuple of data objects with a common number of observations; for example, data = (X, y, w), where X is a table of features, y is a target vector with the same number of rows, and w is a vector of per-observation weights.
Example
In the following, algorithm is some supervised learning algorithm with training features X, training target y, and test features Xnew:
model = fit(algorithm, X, y; verbosity=0)
predict(model, LiteralTarget(), Xnew)
Note that predict does not mutate any argument, except in the special case LearnAPI.predict_or_transform_mutates(algorithm) = true.
See also obspredict, fit, transform, inverse_transform.
Extended help
New implementations
LearnAPI.jl provides the following definition of predict, which is never to be directly overloaded:
predict(model, kop::LearnAPI.KindOfProxy, data...) =
obspredict(model, kop, obs(predict, LearnAPI.algorithm(model), data...))
Rather, new algorithms overload obspredict.
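The sketch below mimics this arrangement in plain Julia, without loading LearnAPI.jl. KindOfProxy, LiteralTarget, MeanPredictor, and the local algorithm, fit, obs and predict definitions are hypothetical stand-ins. The point is that the implementer writes only obspredict, and predict reaches it through obs.

```julia
module DelegationSketch

# Hypothetical stand-ins for LearnAPI.KindOfProxy and LiteralTarget (not the real types).
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end

# Hypothetical toy algorithm: ignores features and predicts the training-target mean.
struct MeanPredictor end
struct MeanModel
    algorithm::MeanPredictor
    mean::Float64
end

# Plays the role of LearnAPI.algorithm(model).
algorithm(model::MeanModel) = model.algorithm

fit(alg::MeanPredictor, X, y) = MeanModel(alg, sum(y) / length(y))

# Without an obs overload, obsdata is just the tuple of data arguments.
obs(f, alg, data...) = data

# Defined once, mirroring the definition quoted above; never overloaded per algorithm.
predict(model, kop::KindOfProxy, data...) =
    obspredict(model, kop, obs(predict, algorithm(model), data...))
predict(model, data...) = predict(model, LiteralTarget(), data...)

# The only prediction method the implementer writes. Observations are rows here.
function obspredict(model::MeanModel, ::LiteralTarget, obsdata)
    Xnew = only(obsdata)  # obsdata is the 1-tuple (Xnew,)
    return fill(model.mean, size(Xnew, 1))
end

end # module
```

Because the three-argument predict method is more specific than the varargs shortcut, predict(model, Xnew) and predict(model, LiteralTarget(), Xnew) both end up in the same obspredict method.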
LearnAPI.obspredict (Function)
obspredict(model, kind_of_proxy::LearnAPI.KindOfProxy, obsdata)
Similar to predict, but consumes algorithm-specific representations of input data, obsdata, as returned by obs(predict, algorithm, data...). Here data... is the form of data expected in the main predict method. Alternatively, such obsdata may be replaced by a resampled version, where resampling is performed using MLUtils.getobs (always supported).
For some algorithms and workflows, obspredict will have a performance benefit over predict. See more at obs.
Example
In the following, algorithm is some supervised learning algorithm with training features X, training target y, and test features Xnew:
model = fit(algorithm, X, y)
obsdata = obs(predict, algorithm, Xnew)
ŷ = obspredict(model, LiteralTarget(), obsdata)
@assert ŷ == predict(model, LiteralTarget(), Xnew)
See also predict, fit, transform, inverse_transform, obs.
Extended help
New implementations
Implementation of obspredict is optional, but required to enable predict. The method must also handle obsdata in the case that it is replaced by MLUtils.getobs(obsdata, I) for some collection I of indices. If obs is not overloaded, then obsdata = data, where data... is what the standard predict call expects, as in the call predict(model, kind_of_proxy, data...). Note that data is always a tuple, even if predict has only one data argument. See more at obs.
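To illustrate the resampling requirement, the sketch below assumes a hypothetical obs overload, obs_for_predict, that stores new features as a matrix with one observation per column. That is the observation convention MLUtils uses for arrays. It also uses getobs_sketch as a stand-in for MLUtils.getobs. None of these names belong to LearnAPI.jl or MLUtils.jl. Because obspredict relies only on the column structure, the same method serves both the full obsdata and any column subset of it.

```julia
module ObsSketch

# Hypothetical stand-ins, not the real LearnAPI.jl types.
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end

# Hypothetical linear model y = w' * x (coefficients supplied directly for brevity).
struct LinearModel
    w::Vector{Float64}
end

# Hypothetical obs overload: a vector of per-observation feature vectors becomes
# a matrix with one observation per column.
obs_for_predict(Xnew::Vector{<:AbstractVector}) = reduce(hcat, Xnew)

# Stand-in for MLUtils.getobs on a matrix: select observations (columns) by index.
getobs_sketch(A::AbstractMatrix, I) = A[:, I]

# Works unchanged on full or resampled obsdata, since it only uses columns.
obspredict(model::LinearModel, ::LiteralTarget, A::AbstractMatrix) = A' * model.w

end # module
```

With these definitions, predicting on getobs_sketch(A, I) gives the same values as predicting on A and then indexing the result with I.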
If LearnAPI.predict_or_transform_mutates(algorithm) is overloaded to return true, then obspredict may mutate its first argument, but not in a way that alters the result of a subsequent call to obspredict, obstransform or inverse_transform. This is necessary for some non-generalizing algorithms but is otherwise discouraged. See more at fit.
If overloaded, you must include both LearnAPI.obspredict and LearnAPI.predict in the list of methods returned by the LearnAPI.functions trait.
Provide an implementation for each kind of target proxy you wish to support. See the LearnAPI.jl documentation for options. Each supported kind_of_proxy instance should be listed in the return value of the LearnAPI.kinds_of_proxy(algorithm) trait.
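The hypothetical sketch below supports two kinds of proxy for one toy classifier. Distribution here is a local stand-in, unrelated to Distributions.jl. The "distribution" it returns is simply a Dict of class probabilities; a real implementation would likely return proper distribution objects. The local kinds_of_proxy only mimics the LearnAPI.kinds_of_proxy trait.

```julia
module ProxySketch

# Hypothetical stand-ins, not the real LearnAPI.jl or Distributions.jl types.
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end
struct Distribution <: KindOfProxy end

# Hypothetical classifier that ignores features and predicts training class frequencies.
struct PriorClassifier end
struct PriorModel
    algorithm::PriorClassifier
    probs::Dict{Symbol,Float64}
end

function fit(alg::PriorClassifier, X, y::AbstractVector{Symbol})
    probs = Dict(c => count(==(c), y) / length(y) for c in unique(y))
    return PriorModel(alg, probs)
end

# Mimics the trait: every kind of proxy that obspredict supports is listed.
kinds_of_proxy(::PriorClassifier) = (LiteralTarget(), Distribution())

# One obspredict method per supported kind of proxy; obsdata is the 1-tuple (Xnew,).
obspredict(model::PriorModel, ::Distribution, obsdata) =
    fill(model.probs, size(only(obsdata), 1))

function obspredict(model::PriorModel, ::LiteralTarget, obsdata)
    # Most probable class: sort (class => probability) pairs by probability.
    best = first(sort(collect(model.probs); by = last, rev = true)).first
    return fill(best, size(only(obsdata), 1))
end

end # module
```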
If, additionally, minimize(model) is overloaded, then the following identity must hold:
obspredict(minimize(model), args...) = obspredict(model, args...)
LearnAPI.transform (Function)
transform(model, data...)
Return a transformation of some data, using some model, as returned by fit.
Arguments
model: anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.
data: tuple of data objects with a common number of observations; for example, data = (X, y, w), where X is a table of features, y is a target vector with the same number of rows, and w is a vector of per-observation weights.
Example
Here X and Xnew are data of the same form:
# For an algorithm that generalizes to new data ("learns"):
model = fit(algorithm, X; verbosity=0)
transform(model, Xnew)
# For a static (non-generalizing) transformer:
model = fit(algorithm)
transform(model, X)
Note that transform does not mutate any argument, except in the special case LearnAPI.predict_or_transform_mutates(algorithm) = true.
See also obstransform, fit, predict, inverse_transform.
Extended help
New implementations
LearnAPI.jl provides the following definition of transform, which is never to be directly overloaded:
transform(model, data...) =
    obstransform(model, obs(transform, LearnAPI.algorithm(model), data...))
Rather, new algorithms overload obstransform.
LearnAPI.obstransform (Function)
obstransform(model, obsdata)
Similar to transform, but consumes algorithm-specific representations of input data, obsdata, as returned by obs(transform, algorithm, data...). Here data... is the form of data expected in the main transform method. Alternatively, such obsdata may be replaced by a resampled version, where resampling is performed using MLUtils.getobs (always supported).
For some algorithms and workflows, obstransform will have a performance benefit over transform. See more at obs.
Example
In the following, algorithm is some unsupervised learning algorithm with training features X and test features Xnew:
model = fit(algorithm, X)
obsdata = obs(transform, algorithm, Xnew)
W = obstransform(model, obsdata)
@assert W == transform(model, Xnew)
See also transform, fit, predict, inverse_transform, obs.
Extended help
New implementations
Implementation of obstransform is optional, but required to enable transform. The method must also handle obsdata in the case that it is replaced by MLUtils.getobs(obsdata, I) for some collection I of indices. If obs is not overloaded, then obsdata = data, where data... is what the standard transform call expects, as in the call transform(model, data...). Note that data is always a tuple, even if transform has only one data argument. See more at obs.
If LearnAPI.predict_or_transform_mutates(algorithm) is overloaded to return true, then obstransform may mutate its first argument, but not in a way that alters the result of a subsequent call to obspredict, obstransform or inverse_transform. This is necessary for some non-generalizing algorithms but is otherwise discouraged. See more at fit.
If overloaded, you must include both LearnAPI.obstransform and LearnAPI.transform in the list of methods returned by the LearnAPI.functions trait.
If, additionally, minimize(model) is overloaded, then the following identity must hold:
obstransform(minimize(model), args...) = obstransform(model, args...)
LearnAPI.inverse_transform (Function)
inverse_transform(model, data)
Inverse transform data according to some model returned by fit. Here "inverse" is to be understood broadly, e.g., as an approximate right inverse for transform.
Arguments
model: anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.
data: something having the same form as the output of transform(model, inputs...).
Example
In the following, algorithm is some dimension-reducing algorithm that generalizes to new data (such as PCA); Xtrain is the training input and Xnew the input to be reduced:
model = fit(algorithm, Xtrain; verbosity=0)
W = transform(model, Xnew) # reduced version of `Xnew`
Ŵ = inverse_transform(model, W) # embedding of `W` in original space
See also fit, transform, predict.
Extended help
New implementations
Implementation is optional. If implemented, you must include inverse_transform in the tuple returned by the LearnAPI.functions trait.
If, additionally, minimize(model) is overloaded, then the following identity must hold:
inverse_transform(minimize(model), args...) = inverse_transform(model, args...)
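Below is a hypothetical sketch of a minimize overload that respects both the obstransform and inverse_transform identities. The toy model carries a copy of its training data. Neither obstransform nor inverse_transform needs that copy, so minimize drops it. Centerer, CentererModel and the local functions are illustrative stand-ins, not LearnAPI.jl definitions.

```julia
module MinimizeSketch

# Hypothetical centering transformer; the model keeps a copy of the training
# data (e.g. for inspection), which transforming does not need.
struct Centerer end
struct CentererModel
    algorithm::Centerer
    means::Vector{Float64}
    training_data::Union{Matrix{Float64},Nothing}
end

# Column means of the training data; observations are rows.
fit(alg::Centerer, X::Matrix{Float64}) =
    CentererModel(alg, vec(sum(X; dims = 1)) ./ size(X, 1), copy(X))

# Drop everything that obstransform and inverse_transform do not use.
minimize(model::CentererModel) = CentererModel(model.algorithm, model.means, nothing)

# obsdata is the 1-tuple (X,); subtract the column means from each row.
obstransform(model::CentererModel, obsdata) = only(obsdata) .- model.means'

# Exact inverse of the centering step.
inverse_transform(model::CentererModel, W) = W .+ model.means'

end # module
```

Because minimize keeps every field that the two methods read, both identities hold by construction.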