predict, transform, and relatives

Standard methods:

predict(model, kind_of_proxy, data...) -> prediction
transform(model, data...) -> transformed_data
inverse_transform(model, data...) -> inverted_data

Methods consuming the output, obsdata, of the data pre-processor obs:

obspredict(model, kind_of_proxy, obsdata) -> prediction
obstransform(model, obsdata) -> transformed_data

Typical workflows

# Train some supervised `algorithm`:
model = fit(algorithm, X, y)

# Predict probability distributions:
ŷ = predict(model, Distribution(), Xnew)

# Generate point predictions:
ŷ = predict(model, LiteralTarget(), Xnew)

# Train a dimension-reducing `algorithm`:
model = fit(algorithm, X)
Xnew_reduced = transform(model, Xnew)

# Apply an approximate right inverse:
inverse_transform(model, Xnew_reduced)
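The supervised half of this workflow can be mimicked in plain Julia. The following is a toy sketch, not LearnAPI.jl: `MeanRegressor`, `fit`, `predict`, `LiteralTarget` and `Distribution` are hypothetical local definitions, observations are assumed to be the columns of a feature matrix, and each predicted "distribution" is represented by a named tuple `(μ, σ)` rather than a Distributions.jl object.

```julia
using Statistics

# Hypothetical stand-ins for the kinds of target proxy (not the LearnAPI.jl types)
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end
struct Distribution <: KindOfProxy end

# Hypothetical supervised algorithm: ignores the features and predicts the
# mean of the training target
struct MeanRegressor end

struct MeanRegressorFitted
    μ::Float64
    σ::Float64
end

fit(::MeanRegressor, X, y) = MeanRegressorFitted(mean(y), std(y))

# Assumption: observations are the columns of a feature matrix
nobs(X::AbstractMatrix) = size(X, 2)

# Point predictions
predict(model::MeanRegressorFitted, ::LiteralTarget, Xnew) =
    fill(model.μ, nobs(Xnew))

# Probabilistic predictions, each "distribution" represented by a named tuple (μ, σ)
predict(model::MeanRegressorFitted, ::Distribution, Xnew) =
    [(μ = model.μ, σ = model.σ) for _ in 1:nobs(Xnew)]

X, y = rand(3, 10), collect(1.0:10.0)
Xnew = rand(3, 4)

model = fit(MeanRegressor(), X, y)
ŷ = predict(model, LiteralTarget(), Xnew)    # four copies of 5.5
d = predict(model, Distribution(), Xnew)     # four named tuples (μ = 5.5, σ = ...)
```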

An advanced workflow

fitdata = obs(fit, algorithm, X, y)
predictdata = obs(predict, algorithm, Xnew)
model = obsfit(algorithm, fitdata)
ŷ = obspredict(model, LiteralTarget(), predictdata)
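A minimal sketch of why this split can pay off, assuming a hypothetical least-squares regressor and local stand-ins for obs, obsfit and obspredict (none of these definitions is LearnAPI.jl code): the conversion from a table to a matrix happens once, in obs, and the core methods only ever see the converted form.

```julia
# Hypothetical stand-ins; none of these names is the LearnAPI.jl API
function fit end
function predict end

struct LiteralTarget end

struct LeastSquares end                 # hypothetical algorithm (no intercept)

struct LeastSquaresFitted
    algorithm::LeastSquares
    coefficients::Vector{Float64}
end

algorithm(model::LeastSquaresFitted) = model.algorithm

# Data front end: convert a table (a named tuple of columns) into a matrix
# whose rows are observations, once, ahead of time
obs(::typeof(fit), ::LeastSquares, X, y) = (hcat(values(X)...), y)
obs(::typeof(predict), ::LeastSquares, Xnew) = hcat(values(Xnew)...)

# Core methods consume only the pre-processed representation
function obsfit(algo::LeastSquares, obsdata)
    A, y = obsdata
    return LeastSquaresFitted(algo, A \ y)
end
obspredict(model::LeastSquaresFitted, ::LiteralTarget, A::AbstractMatrix) =
    A * model.coefficients

# User-facing wrappers
fit(algo::LeastSquares, X, y) = obsfit(algo, obs(fit, algo, X, y))
predict(model::LeastSquaresFitted, kop::LiteralTarget, Xnew) =
    obspredict(model, kop, obs(predict, algorithm(model), Xnew))

X = (a = [1.0, 2.0, 3.0, 4.0], b = [0.0, 1.0, 0.0, 1.0])
y = 2 .* X.a .+ 3 .* X.b                 # exact linear relationship
Xnew = (a = [5.0, 6.0], b = [1.0, 0.0])

algo = LeastSquares()
fitdata = obs(fit, algo, X, y)
predictdata = obs(predict, algo, Xnew)
model = obsfit(algo, fitdata)
ŷ = obspredict(model, LiteralTarget(), predictdata)   # ≈ [13.0, 12.0]
```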

Implementation guide

The methods predict and transform are not to be overloaded directly. Implement obspredict and obstransform instead:

method             compulsory?  fallback  requires
obspredict         no           none      fit
obstransform       no           none      fit
inverse_transform  no           none      fit, obstransform

Predict or transform?

If the algorithm has a notion of target variable, then arrange for obspredict to output each supported kind of target proxy (LiteralTarget(), Distribution(), etc.).

For output not associated with a target variable, implement obstransform instead. It does not dispatch on LearnAPI.KindOfProxy, but can optionally be paired with an implementation of inverse_transform, which returns (approximate) right inverses of transform.
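A sketch of the two cases, using hypothetical local definitions rather than LearnAPI.jl: `MedianRegressor` has a target variable, so its `predict` dispatches on a kind of proxy; `Standardizer` has none, so it provides `transform` (without a proxy argument) together with an exact `inverse_transform`.

```julia
using Statistics

struct LiteralTarget end   # hypothetical stand-in for a kind of target proxy

# With a target variable: predictions dispatch on the kind of proxy requested
struct MedianRegressor end
struct MedianRegressorFitted
    m::Float64
end
fit(::MedianRegressor, X, y) = MedianRegressorFitted(median(y))
predict(model::MedianRegressorFitted, ::LiteralTarget, Xnew) =
    fill(model.m, size(Xnew, 2))          # observations are columns

# Without a target variable: `transform` takes no kind-of-proxy argument
# and is paired here with an (exact) inverse
struct Standardizer end
struct StandardizerFitted
    μ::Float64
    σ::Float64
end
fit(::Standardizer, x) = StandardizerFitted(mean(x), std(x))
transform(model::StandardizerFitted, x) = (x .- model.μ) ./ model.σ
inverse_transform(model::StandardizerFitted, z) = z .* model.σ .+ model.μ

r = fit(MedianRegressor(), rand(2, 5), [1.0, 5.0, 2.0, 4.0, 3.0])
ŷ = predict(r, LiteralTarget(), rand(2, 3))   # [3.0, 3.0, 3.0]

x = [1.0, 2.0, 3.0, 4.0]
s = fit(Standardizer(), x)
z = transform(s, x)
x_back = inverse_transform(s, z)               # recovers x, up to rounding
```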

Reference

LearnAPI.predict (Function)
predict(model, kind_of_proxy::LearnAPI.KindOfProxy, data...)
predict(model, data...)

The first signature returns target or target proxy predictions for input features data, according to some model returned by fit or obsfit. Where supported, these are literal target predictions if kind_of_proxy = LiteralTarget(), and probability density/mass functions if kind_of_proxy = Distribution(). List all options with LearnAPI.kinds_of_proxy(algorithm), where algorithm = LearnAPI.algorithm(model).

The shortcut predict(model, data...) = predict(model, LiteralTarget(), data...) is also provided.

Arguments

  • model is anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.

  • data: tuple of data objects with a common number of observations, for example, data = (X, y, w) where X is a table of features, y is a target vector with the same number of rows, and w a vector of per-observation weights.

Example

In the following, algorithm is some supervised learning algorithm with training features X, training target y, and test features Xnew:

model = fit(algorithm, X, y; verbosity=0)
predict(model, LiteralTarget(), Xnew)

Note that predict does not mutate any argument, except in the special case LearnAPI.predict_or_transform_mutates(algorithm) = true.

See also obspredict, fit, transform, inverse_transform.

Extended help

New implementations

LearnAPI.jl provides the following definition of predict, which is never to be overloaded directly:

predict(model, kop::LearnAPI.KindOfProxy, data...) =
    obspredict(model, kop, obs(predict, LearnAPI.algorithm(model), data...))

Rather, new algorithms overload obspredict.
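To show how this delegation plays out, here is a self-contained sketch in which the generic methods, the default `obs` (returning its data tuple unchanged) and the kinds of proxy are all hypothetical local stand-ins for the LearnAPI.jl versions. The toy `Doubler` implementation overloads only `algorithm` and `obspredict`, yet both `predict` signatures work. Requesting an unsupported kind of proxy, such as `Distribution()`, reaches `obspredict` and fails with a `MethodError`.

```julia
# Hypothetical stand-ins mirroring the pattern described above; not LearnAPI.jl itself
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end
struct Distribution <: KindOfProxy end

function predict end

# Default data front end: `obs` returns the data tuple unchanged
obs(::typeof(predict), algorithm, data...) = data

# Generic user-facing methods, which implementations never overload
predict(model, kop::KindOfProxy, data...) =
    obspredict(model, kop, obs(predict, algorithm(model), data...))
predict(model, data...) = predict(model, LiteralTarget(), data...)   # shortcut

# An implementation overloads only `algorithm` and `obspredict`
struct Doubler end                 # hypothetical algorithm
struct DoublerFitted end           # its (trivial) fitted model
algorithm(::DoublerFitted) = Doubler()

# With the default `obs`, `obsdata` is the one-element tuple `(Xnew,)`
obspredict(::DoublerFitted, ::LiteralTarget, obsdata::Tuple) = 2 .* only(obsdata)

model = DoublerFitted()
ŷ = predict(model, LiteralTarget(), [1, 2, 3])   # [2, 4, 6]
ŷ_short = predict(model, [1, 2, 3])              # same result, via the shortcut
```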

LearnAPI.obspredict (Function)
obspredict(model, kind_of_proxy::LearnAPI.KindOfProxy, obsdata)

Similar to predict but consumes algorithm-specific representations of input data, obsdata, as returned by obs(predict, algorithm, data...). Here data... is the form of data expected in the main predict method. Alternatively, such obsdata may be replaced by a resampled version, where resampling is performed using MLUtils.getobs (always supported).

For some algorithms and workflows, obspredict will have a performance benefit over predict. See more at obs.

Example

In the following, algorithm is some supervised learning algorithm with training features X, training target y, and test features Xnew:

model = fit(algorithm, X, y)
obsdata = obs(predict, algorithm, Xnew)
ŷ = obspredict(model, LiteralTarget(), obsdata)
@assert ŷ == predict(model, LiteralTarget(), Xnew)

See also predict, fit, transform, inverse_transform, obs.

Extended help

New implementations

Implementation of obspredict is optional, but required to enable predict. The method must also handle obsdata in the case that it is replaced by MLUtils.getobs(obsdata, I) for some collection I of indices. If obs is not overloaded, then obsdata = data, where data... is what the standard predict call expects, as in the call predict(model, kind_of_proxy, data...). Note that data is always a tuple, even if predict has only one data argument. See more at obs.
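The resampling requirement is easiest to meet by writing obspredict so that it makes no assumption about how many observations it receives. In this sketch, `getobs` is a local stand-in for MLUtils.getobs that follows the convention that the columns of a matrix are observations, and `predict_obsdata` is a hypothetical front end standing in for an overloaded `obs`.

```julia
struct LiteralTarget end     # hypothetical stand-in

# Local stand-in for MLUtils.getobs, following the convention that the last
# dimension of an array (here, the columns of a matrix) indexes observations
getobs(A::AbstractMatrix, I) = A[:, I]

struct LinearModel           # hypothetical fitted model
    w::Vector{Float64}       # one weight per feature (row)
end

# Hypothetical data front end: convert the features to a Float64 matrix once
predict_obsdata(Xnew::AbstractMatrix) = Matrix{Float64}(Xnew)

# Makes no assumption about the number of columns, so it also accepts
# the resampled `getobs(obsdata, I)`
obspredict(model::LinearModel, ::LiteralTarget, A::AbstractMatrix) = A' * model.w

model = LinearModel([1.0, 10.0])
Xnew = [1 2 3 4;
        0 1 0 1]
obsdata = predict_obsdata(Xnew)
ŷ_all = obspredict(model, LiteralTarget(), obsdata)                   # [1.0, 12.0, 3.0, 14.0]
ŷ_sub = obspredict(model, LiteralTarget(), getobs(obsdata, [2, 4]))   # [12.0, 14.0]
```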

If LearnAPI.predict_or_transform_mutates(algorithm) is overloaded to return true, then obspredict may mutate its first argument, but not in a way that alters the result of a subsequent call to obspredict, obstransform or inverse_transform. This is necessary for some non-generalizing algorithms but is otherwise discouraged. See more at fit.

If overloaded, you must include both LearnAPI.obspredict and LearnAPI.predict in the list of methods returned by the LearnAPI.functions trait.

Provide an implementation for each kind of target proxy you wish to support. See the LearnAPI.jl documentation for options. Each supported kind_of_proxy instance should be listed in the return value of the LearnAPI.kinds_of_proxy(algorithm) trait.
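A sketch of keeping the implementations and the trait in step, with hypothetical local versions of the proxy types and of `kinds_of_proxy`; `advertised_kinds_implemented` is a made-up helper, not part of LearnAPI.jl, that checks each advertised kind has a matching `obspredict` method.

```julia
# Hypothetical stand-ins; not the LearnAPI.jl types or traits
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end
struct Distribution <: KindOfProxy end

struct MeanRegressor end              # hypothetical algorithm
struct MeanRegressorFitted
    μ::Float64
    σ::Float64
end

# One method per supported kind of proxy (observations are matrix columns)
obspredict(m::MeanRegressorFitted, ::LiteralTarget, obsdata) =
    fill(m.μ, size(only(obsdata), 2))
obspredict(m::MeanRegressorFitted, ::Distribution, obsdata) =
    [(μ = m.μ, σ = m.σ) for _ in axes(only(obsdata), 2)]

# The trait lists exactly the kinds implemented above
kinds_of_proxy(::MeanRegressor) = (LiteralTarget(), Distribution())

# Hypothetical consistency check: every advertised kind has a method
advertised_kinds_implemented(algorithm, model, obsdata) =
    all(kop -> applicable(obspredict, model, kop, obsdata), kinds_of_proxy(algorithm))

model = MeanRegressorFitted(2.0, 1.0)
obsdata = (rand(3, 5),)                # default `obs` output: a tuple
```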

If, additionally, minimize(model) is overloaded, then the following identity must hold:

obspredict(minimize(model), args...) = obspredict(model, args...)
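For illustration, assume a hypothetical fitted model that keeps a copy of its training data it does not need for prediction. A `minimize` that drops that data (both names are local stand-ins here) satisfies the identity as long as `obspredict` reads only the retained fields:

```julia
struct LiteralTarget end      # hypothetical stand-in

# Hypothetical fitted model that also retains its training data
struct Fitted
    coefficient::Float64
    training_data::Union{Nothing, Vector{Float64}}   # not used by obspredict
end

# Hypothetical `minimize`: drop what obspredict does not need
minimize(model::Fitted) = Fitted(model.coefficient, nothing)

obspredict(model::Fitted, ::LiteralTarget, x::AbstractVector) = model.coefficient .* x

model = Fitted(3.0, collect(1.0:1000.0))
small = minimize(model)
args = (LiteralTarget(), [1.0, 2.0])
```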
LearnAPI.transform (Function)
transform(model, data...)

Return a transformation of some data, using some model, as returned by fit.

Arguments

  • model is anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.

  • data: tuple of data objects with a common number of observations, for example, data = (X, y, w) where X is a table of features, y is a target vector with the same number of rows, and w a vector of per-observation weights.

Example

Here X and Xnew are data of the same form:

# For an algorithm that generalizes to new data ("learns"):
model = fit(algorithm, X; verbosity=0)
transform(model, Xnew)

# For a static (non-generalizing) transformer:
model = fit(algorithm)
transform(model, X)

Note that transform does not mutate any argument, except in the special case LearnAPI.predict_or_transform_mutates(algorithm) = true.

See also obstransform, fit, predict, inverse_transform.

Extended help

New implementations

LearnAPI.jl provides the following definition of transform, which is never to be overloaded directly:

transform(model, data...) =
    obstransform(model, obs(transform, LearnAPI.algorithm(model), data...))

Rather, new algorithms overload obstransform.
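The same delegation for transform is sketched below with hypothetical local stand-ins (`obs`, `algorithm`, `obstransform` and a toy `Scaler`), not LearnAPI.jl code. Here the generic method passes `transform` itself to `obs`, and there is no kind-of-proxy argument anywhere.

```julia
function transform end

# Hypothetical default data front end: return the data tuple unchanged
obs(::typeof(transform), algorithm, data...) = data

# Generic user-facing method: passes `transform` itself to `obs` and takes
# no kind-of-proxy argument
transform(model, data...) =
    obstransform(model, obs(transform, algorithm(model), data...))

# An implementation overloads only `algorithm` and `obstransform`
struct Scaler                     # hypothetical algorithm
    factor::Float64
end
struct ScalerFitted               # its fitted model
    algorithm::Scaler
end
algorithm(model::ScalerFitted) = model.algorithm

# With the default `obs`, `obsdata` is the one-element tuple `(X,)`
obstransform(model::ScalerFitted, obsdata::Tuple) =
    model.algorithm.factor .* only(obsdata)

model = ScalerFitted(Scaler(0.5))
Xt = transform(model, [2.0, 4.0])   # [1.0, 2.0]
```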

LearnAPI.obstransform (Function)
obstransform(model, obsdata)

Similar to transform but consumes algorithm-specific representations of input data, obsdata, as returned by obs(transform, algorithm, data...). Here data... is the form of data expected in the main transform method. Alternatively, such obsdata may be replaced by a resampled version, where resampling is performed using MLUtils.getobs (always supported).

For some algorithms and workflows, obstransform will have a performance benefit over transform. See more at obs.

Example

In the following, algorithm is some unsupervised learning algorithm with training features X, and test features Xnew:

model = fit(algorithm, X)
obsdata = obs(transform, algorithm, Xnew)
W = obstransform(model, obsdata)
@assert W == transform(model, Xnew)

See also transform, fit, predict, inverse_transform, obs.

Extended help

New implementations

Implementation of obstransform is optional, but required to enable transform. The method must also handle obsdata in the case that it is replaced by MLUtils.getobs(obsdata, I) for some collection I of indices. If obs is not overloaded, then obsdata = data, where data... is what the standard transform call expects, as in the call transform(model, data...). Note that data is always a tuple, even if transform has only one data argument. See more at obs.

If LearnAPI.predict_or_transform_mutates(algorithm) is overloaded to return true, then obstransform may mutate its first argument, but not in a way that alters the result of a subsequent call to obspredict, obstransform or inverse_transform. This is necessary for some non-generalizing algorithms but is otherwise discouraged. See more at fit.

If overloaded, you must include both LearnAPI.obstransform and LearnAPI.transform in the list of methods returned by the LearnAPI.functions trait.


If, additionally, minimize(model) is overloaded, then the following identity must hold:

obstransform(minimize(model), args...) = obstransform(model, args...)
LearnAPI.inverse_transform (Function)
inverse_transform(model, data)

Inverse transform data according to some model returned by fit. Here "inverse" is to be understood broadly, e.g., an approximate right inverse for transform.

Arguments

  • model: anything returned by a call of the form fit(algorithm, ...), for some LearnAPI-compliant algorithm.

  • data: something having the same form as the output of transform(model, inputs...)

Example

In the following, algorithm is some dimension-reducing algorithm that generalizes to new data (such as PCA); Xtrain is the training input and Xnew the input to be reduced:

model = fit(algorithm, Xtrain; verbosity=0)
W = transform(model, Xnew)       # reduced version of `Xnew`
Ŵ = inverse_transform(model, W)  # embedding of `W` in original space
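To see in what sense the inverse is a right inverse, consider a hypothetical PCA-like reducer (a local sketch, not a LearnAPI.jl implementation) that projects centred data onto k orthonormal directions V, with observations as columns. Because V'V = I, transform(model, inverse_transform(model, W)) recovers W up to rounding, whereas inverse_transform(model, transform(model, X)) returns only the projection of X onto the retained directions.

```julia
using LinearAlgebra

# Hypothetical PCA-like reducer (not a LearnAPI.jl algorithm); observations
# are the columns of the data matrix
struct Reducer
    k::Int                          # number of directions to keep
end

struct ReducerFitted
    μ::Vector{Float64}              # feature means
    V::Matrix{Float64}              # p × k matrix with orthonormal columns
end

function fit(r::Reducer, X::AbstractMatrix)
    μ = vec(sum(X; dims=2)) ./ size(X, 2)
    U = svd(X .- μ).U               # left singular vectors of the centred data
    return ReducerFitted(μ, U[:, 1:r.k])
end

transform(m::ReducerFitted, X) = m.V' * (X .- m.μ)
inverse_transform(m::ReducerFitted, W) = m.V * W .+ m.μ

Xtrain = randn(3, 10)
model = fit(Reducer(2), Xtrain)
W = transform(model, randn(3, 4))   # 2 × 4 reduced representation
Xrecon = inverse_transform(model, W) # 3 × 4 embedding in the original space
```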

See also fit, transform, predict.

Extended help

New implementations

Implementation is optional. If implemented, you must include inverse_transform in the tuple returned by the LearnAPI.functions trait.

If, additionally, minimize(model) is overloaded, then the following identity must hold:

inverse_transform(minimize(model), args...) = inverse_transform(model, args...)