Algorithm Traits
Traits generally promise specific algorithm behavior, such as: "This algorithm supports per-observation weights, which must appear as the third argument of fit", or "This algorithm's transform method predicts Real vectors". They also record more mundane information, such as a package license.
Algorithm traits are functions whose first (and usually only) argument is an algorithm.
Special two-argument traits
The two-argument versions of LearnAPI.predict_output_scitype and LearnAPI.predict_output_type are the only overloadable traits with more than one argument.
Trait summary
Overloadable traits
In the examples column of the table below, Table
, Continuous
, Sampleable
are names owned by the package ScientificTypesBase.jl.
trait | return value | fallback value | example |
---|---|---|---|
LearnAPI.functions (algorithm) | functions you can apply to algorithm or associated model (traits excluded) | () | (LearnAPI.fit, LearnAPI.predict, LearnAPI.algorithm) |
LearnAPI.kinds_of_proxy (algorithm) | instances kop of KindOfProxy for which an implementation of LearnAPI.predict(algorithm, kop, ...) is guaranteed | () | (Distribution(), ConfidenceInterval()) |
LearnAPI.position_of_target (algorithm) | the positional index¹ of the target in data in fit(algorithm, data...) calls | 0 | 2 |
LearnAPI.position_of_weights (algorithm) | the positional index¹ of per-observation weights in data in fit(algorithm, data...) | 0 | 3 |
LearnAPI.descriptors (algorithm) | lists one or more suggestive algorithm descriptors from LearnAPI.descriptors() | () | (:regression, :probabilistic) |
LearnAPI.is_pure_julia (algorithm) | true if implementation is 100% Julia code | false | true |
LearnAPI.pkg_name (algorithm) | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | "unknown" | "DecisionTree" |
LearnAPI.pkg_license (algorithm) | name of license of package providing core code | "unknown" | "MIT" |
LearnAPI.doc_url (algorithm) | url providing documentation of the core code | "unknown" | "https://en.wikipedia.org/wiki/Decision_tree_learning" |
LearnAPI.load_path (algorithm) | a string indicating where the struct for typeof(algorithm) is defined, beginning with name of package providing implementation | "unknown" | "FastTrees.LearnAPI.DecisionTreeClassifier" |
LearnAPI.is_composite (algorithm) | true if one or more properties (fields) of algorithm may be an algorithm | false | true |
LearnAPI.human_name (algorithm) | human name for the algorithm; should be a noun | type name with spaces | "elastic net regressor" |
LearnAPI.iteration_parameter (algorithm) | symbolic name of an iteration parameter | nothing | :epochs |
LearnAPI.fit_scitype (algorithm) | upper bound on scitype(data) ensuring fit(algorithm, data...) works | Union{} | Tuple{Table(Continuous), AbstractVector{Continuous}} |
LearnAPI.fit_observation_scitype (algorithm) | upper bound on scitype(observation) for observation in data ensuring fit(algorithm, data...) works | Union{} | Tuple{AbstractVector{Continuous}, Continuous} |
LearnAPI.fit_type (algorithm) | upper bound on typeof(data) ensuring fit(algorithm, data...) works | Union{} | Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Real}} |
LearnAPI.fit_observation_type (algorithm) | upper bound on typeof(observation) for observation in data ensuring fit(algorithm, data...) works | Union{} | Tuple{AbstractVector{<:Real}, Real} |
LearnAPI.predict_input_scitype (algorithm) | upper bound on scitype(data) ensuring predict(model, kop, data...) works | Union{} | Table(Continuous) |
LearnAPI.predict_input_observation_scitype (algorithm) | upper bound on scitype(observation) for observation in data ensuring predict(model, kop, data...) works | Union{} | Vector{Continuous} |
LearnAPI.predict_input_type (algorithm) | upper bound on typeof(data) ensuring predict(model, kop, data...) works | Union{} | AbstractMatrix{<:Real} |
LearnAPI.predict_input_observation_type (algorithm) | upper bound on typeof(observation) for observation in data ensuring predict(model, kop, data...) works | Union{} | Vector{<:Real} |
LearnAPI.predict_output_scitype (algorithm, kind_of_proxy) | upper bound on scitype(predict(model, ...)) | Any | AbstractVector{Continuous} |
LearnAPI.predict_output_type (algorithm, kind_of_proxy) | upper bound on typeof(predict(model, ...)) | Any | AbstractVector{<:Real} |
LearnAPI.transform_input_scitype (algorithm) | upper bound on scitype(data) ensuring transform(model, data...) works | Union{} | Table(Continuous) |
LearnAPI.transform_input_observation_scitype (algorithm) | upper bound on scitype(observation) for observation in data ensuring transform(model, data...) works | Union{} | Vector{Continuous} |
LearnAPI.transform_input_type (algorithm) | upper bound on typeof(data) ensuring transform(model, data...) works | Union{} | AbstractMatrix{<:Real} |
LearnAPI.transform_input_observation_type (algorithm) | upper bound on typeof(observation) for observation in data ensuring transform(model, data...) works | Union{} | Vector{<:Real} |
LearnAPI.transform_output_scitype (algorithm) | upper bound on scitype(transform(model, ...)) | Any | Table(Continuous) |
LearnAPI.transform_output_type (algorithm) | upper bound on typeof(transform(model, ...)) | Any | AbstractMatrix{<:Real} |
LearnAPI.predict_or_transform_mutates (algorithm) | true if predict or transform mutates first argument | false | true |
¹ If the value is 0, then the variable in question (the target or the per-observation weights) is not supported and not expected to appear in data. If length(data) is less than the trait value, then data is understood to exclude the variable. Note, however, that fit can have multiple signatures of varying lengths, as in fit(algorithm, X, y) and fit(algorithm, X, y, w). A non-zero value is a promise that fit includes a signature of sufficient length to include the variable.
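As a rough illustration of the footnote, a consumer of the position traits might pull the target or weights out of data as sketched below. This is not LearnAPI.jl code: the helper name `extract` is hypothetical, and the positions are passed in directly rather than read from LearnAPI.position_of_target or LearnAPI.position_of_weights.

```julia
# Hypothetical helper (not part of LearnAPI.jl): return the element of `data`
# at `position`, or `nothing` if the variable is unsupported (position 0) or
# absent from this particular call (position exceeds length(data)).
function extract(data::Tuple, position::Integer)
    position == 0 && return nothing
    position > length(data) && return nothing
    return data[position]
end

X = rand(5, 2)
y = rand(5)

# Suppose the target position is 2 and the weights position is 3:
target = extract((X, y), 2)    # the vector `y`
weights = extract((X, y), 3)   # `nothing`, since no weights were supplied
```

A caller receiving `nothing` for the weights would then treat them as uniform, in line with the description of LearnAPI.position_of_weights.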
Derived Traits
The following convenience methods are provided but not overloadable by new implementations.
trait | return value | example |
---|---|---|
LearnAPI.name(algorithm) | algorithm type name as string | "PCA" |
LearnAPI.is_algorithm(algorithm) | true if LearnAPI.functions(algorithm) is not empty | true |
LearnAPI.predict_output_scitype(algorithm) | dictionary of upper bounds on the scitype of predictions, keyed on subtypes of LearnAPI.KindOfProxy | |
LearnAPI.predict_output_type(algorithm) | dictionary of upper bounds on the type of predictions, keyed on subtypes of LearnAPI.KindOfProxy | |
Implementation guide
A single-argument trait is declared following this pattern:
LearnAPI.is_pure_julia(algorithm::MyAlgorithmType) = true
A shorthand for single-argument traits is available:
@trait MyAlgorithmType is_pure_julia=true
Multiple traits can be declared like this:
@trait(
MyAlgorithmType,
is_pure_julia = true,
pkg_name = "MyPackage",
)
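The declaration pattern above relies on ordinary Julia dispatch: a generic fallback method supplies the default value, and a more specific method for the algorithm type overrides it. The following self-contained sketch mimics this with toy functions defined locally; `is_pure_julia` and `pkg_name` here are stand-ins rather than the LearnAPI.jl functions, and `MyAlgorithmType` is a hypothetical type.

```julia
# Toy stand-ins for two traits, each with a fallback for arbitrary objects
is_pure_julia(::Any) = false
pkg_name(::Any) = "unknown"

# Hypothetical algorithm type
struct MyAlgorithmType end

# Overloads following the single-argument pattern
is_pure_julia(::MyAlgorithmType) = true
pkg_name(::MyAlgorithmType) = "MyPackage"

algorithm = MyAlgorithmType()
declared = (is_pure_julia(algorithm), pkg_name(algorithm))
undeclared = (is_pure_julia(42), pkg_name(42))   # fallbacks apply
```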
The global trait contracts
To ensure that trait metadata can be stored in an external algorithm registry, LearnAPI.jl requires:
1. Finiteness: The value of a trait is the same for all algorithms with the same underlying UnionAll type. That is, even if the type parameters are different, the trait value should be the same. There is an exception if is_composite(algorithm) = true.
2. Serializability: The value of any trait can be evaluated without installing any third-party package; using LearnAPI should suffice.
Because of 1, combining a lot of functionality into one algorithm (e.g., an algorithm that can perform either classification or regression) can mean traits are necessarily less informative (as in LearnAPI.predict_output_type(algorithm, kind_of_proxy) = Any).
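One way to satisfy the finiteness contract is to dispatch on the parameter-free (UnionAll) type, so that every parameterization shares a single trait value. The sketch below is illustrative only: `MyKernelRegressor` is a hypothetical parametric type and `pkg_license` is a locally defined stand-in, not the LearnAPI.jl trait.

```julia
# Hypothetical parametric algorithm type
struct MyKernelRegressor{K}
    kernel::K
end

# Toy trait with a fallback
pkg_license(::Any) = "unknown"

# Dispatching on `MyKernelRegressor` without parameters covers
# MyKernelRegressor{K} for every K, so the value cannot vary with K.
pkg_license(::MyKernelRegressor) = "MIT"

a = MyKernelRegressor(sin)      # MyKernelRegressor{typeof(sin)}
b = MyKernelRegressor("rbf")    # MyKernelRegressor{String}
same_value = pkg_license(a) == pkg_license(b)
```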
Reference
LearnAPI.functions — Function
LearnAPI.functions(algorithm)
Return a tuple of functions that can be sensibly applied to algorithm
, or to objects having the same type as algorithm
, or to associated models (objects returned by fit(algorithm, ...)). Algorithm traits are excluded.
In addition to functions, the returned tuple may include expressions, like :(DecisionTree.print_tree)
, which reference functions not owned by LearnAPI.jl.
The understanding is that algorithm
is a LearnAPI-compliant object whenever this is non-empty.
Extended help
New implementations
All new implementations must overload this trait. Here's a checklist for elements in the return value:
function | needs explicit implementation? | include in returned tuple? |
---|---|---|
fit | no | yes |
obsfit | yes | yes |
minimize | optional | yes |
predict | no | if obspredict is implemented |
obspredict | optional | if implemented |
transform | no | if obstransform is implemented |
obstransform | optional | if implemented |
obs | optional | yes |
inverse_transform | optional | if implemented |
LearnAPI.algorithm | yes | yes |
Also include any implemented accessor functions. The LearnAPI.jl accessor functions are: LearnAPI.extras
, LearnAPI.algorithm
, LearnAPI.coefficients
, LearnAPI.intercept
, LearnAPI.tree
, LearnAPI.trees
, LearnAPI.feature_importances
, LearnAPI.training_labels
, LearnAPI.training_losses
, LearnAPI.training_scores
and LearnAPI.components
.
LearnAPI.kinds_of_proxy — Function
LearnAPI.kinds_of_proxy(algorithm)
Returns a tuple of all instances, kind, for which predict(algorithm, kind, data...) has a guaranteed implementation. Each such kind is an instance of a subtype of LearnAPI.KindOfProxy. Examples are LiteralTarget() (for predicting actual target values) and Distribution() (for predicting probability mass/density functions).
See also LearnAPI.predict
, LearnAPI.KindOfProxy
.
Extended help
New implementations
Implementation is optional but recommended whenever predict
is overloaded.
Elements of the returned tuple must be one of these: ConfidenceInterval
, Continuous
, Distribution
, LabelAmbiguous
, LabelAmbiguousDistribution
, LabelAmbiguousSampleable
, LiteralTarget
, LogDistribution
, LogProbability
, OutlierScore
, Parametric
, ProbabilisticSet
, Probability
, Sampleable
, Set
, SurvivalDistribution
, SurvivalFunction
, IID
, JointDistribution
, JointLogDistribution
and JointSampleable
.
Suppose, for example, we have the following implementation of a supervised learner returning only probabilistic predictions:
LearnAPI.predict(algorithm::MyNewAlgorithmType, ::LearnAPI.Distribution, Xnew) = ...
Then we can declare
@trait MyNewAlgorithmType kinds_of_proxy = (LearnAPI.Distribution(),)
For more on target variables and target proxies, refer to the LearnAPI documentation.
LearnAPI.position_of_target — Function
LearnAPI.position_of_target(algorithm)
Return the expected position of the target variable within data in calls of the form LearnAPI.fit(algorithm, data...).
If this number is 0
, then no target is expected. If this number exceeds length(data)
, then data
is understood to exclude the target variable.
LearnAPI.position_of_weights — Function
LearnAPI.position_of_weights(algorithm)
Return the expected position of per-observation weights within data
in calls of the form LearnAPI.fit
(algorithm, data...)
.
If this number is 0
, then no weights are expected. If this number exceeds length(data)
, then data
is understood to exclude weights, which are assumed to be uniform.
LearnAPI.descriptors — Function
LearnAPI.descriptors(algorithm)
Lists one or more suggestive algorithm descriptors from this list: :regression
, :classification
, :clustering
, :gradient_descent
, :iterative_algorithms
, :incremental_algorithms
, :dimension_reduction
, :encoders
, :static_algorithms
, :missing_value_imputation
, :ensemble_algorithms
, :wrappers
, :time_series_forecasting
, :time_series_classification
, :survival_analysis
, :distribution_fitters
, :Bayesian_algorithms
, :outlier_detection
, :collaborative_filtering
, :text_analysis
, :audio_analysis
, :natural_language_processing
, :image_processing
(do LearnAPI.descriptors()
to reproduce).
The value of this trait guarantees no particular behavior. The trait is intended for informal classification purposes only.
New implementations
This trait should return a tuple of symbols, as in (:classification, :text_analysis).
LearnAPI.is_pure_julia — Function
LearnAPI.is_pure_julia(algorithm)
Returns true
if training algorithm
requires evaluation of pure Julia code only.
New implementations
The fallback is false
.
LearnAPI.pkg_name — Function
LearnAPI.pkg_name(algorithm)
Return the name of the package module which supplies the core training algorithm for algorithm
. This is not necessarily the package providing the LearnAPI interface.
Returns "unknown"
if the algorithm implementation has failed to overload the trait.
New implementations
Must return a string, as in "DecisionTree"
.
LearnAPI.pkg_license — Function
LearnAPI.pkg_license(algorithm)
Return the name of the software license, such as "MIT"
, applying to the package where the core algorithm for algorithm
is implemented.
LearnAPI.doc_url — Function
LearnAPI.doc_url(algorithm)
Return a url where the core algorithm for algorithm
is documented.
Returns "unknown"
if the algorithm implementation has failed to overload the trait.
New implementations
Must return a string, such as "https://en.wikipedia.org/wiki/Decision_tree_learning"
.
LearnAPI.load_path — Function
LearnAPI.load_path(algorithm)
Return a string indicating where the struct
for typeof(algorithm)
can be found, beginning with the name of the package module defining it. For example, a return value of "FastTrees.LearnAPI.DecisionTreeClassifier"
means the following Julia code will return the algorithm type:
import FastTrees
FastTrees.LearnAPI.DecisionTreeClassifier
Returns "unknown"
if the algorithm implementation has failed to overload the trait.
LearnAPI.is_composite — Function
LearnAPI.is_composite(algorithm)
Returns true
if one or more properties (fields) of algorithm
may themselves be algorithms, and false
otherwise.
See also LearnAPI.components.
New implementations
This trait should be overloaded if one or more properties (fields) of algorithm
may take algorithm values. Fallback return value is false
. The keyword constructor for such an algorithm need not prescribe defaults for algorithm-valued properties. Implementation of the accessor function LearnAPI.components
is recommended.
The value of the trait must depend only on the type of algorithm
.
LearnAPI.human_name — Function
LearnAPI.human_name(algorithm)
A human-readable string representation of typeof(algorithm)
. Primarily intended for auto-generation of documentation.
New implementations
Optional. A fallback takes the type name, inserts spaces and removes capitalization. For example, KNNRegressor
becomes "knn regressor"
. Better would be to overload the trait to return "K-nearest neighbors regressor"
. Ideally, this is a "concrete" noun like "ridge regressor"
rather than an "abstract" noun like "ridge regression"
.
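The fallback behavior described above (insert spaces into the type name, then remove capitalization) can be approximated as follows. This is a hypothetical re-implementation for illustration, not the code used by LearnAPI.jl; the function name `spaced_lowercase` is an assumption.

```julia
# Hypothetical sketch of the described fallback: split a CamelCase type name
# into words and convert to lower case.
function spaced_lowercase(type_name::AbstractString)
    # lower-case letter followed by upper-case: "DecisionTree" -> "Decision Tree"
    s = replace(type_name, r"([a-z])([A-Z])" => s"\1 \2")
    # end of an acronym followed by a capitalized word: "KNNRegressor" -> "KNN Regressor"
    s = replace(s, r"([A-Z])([A-Z][a-z])" => s"\1 \2")
    return lowercase(s)
end

struct KNNRegressor end   # stand-in algorithm type

fallback_name = spaced_lowercase(string(nameof(KNNRegressor)))
```

Because a purely mechanical fallback cannot expand acronyms, overloading the trait remains the better option whenever a more descriptive name exists.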
LearnAPI.iteration_parameter — Function
LearnAPI.iteration_parameter(algorithm)
The name of the iteration parameter of algorithm
, or nothing
if the algorithm is not iterative.
New implementations
Implement if algorithm is iterative. Returns a symbol or nothing
.
LearnAPI.fit_scitype — Function
LearnAPI.fit_scitype(algorithm)
Return an upper bound on the scitype of data
guaranteed to work when calling fit(algorithm, data...)
.
Specifically, if the return value is S
and ScientificTypes.scitype(data) <: S
, then all the following calls are guaranteed to work:
fit(algorithm, data...)
obsdata = obs(fit, algorithm, data...)
fit(algorithm, Obs(), obsdata)
See also LearnAPI.fit_type
, LearnAPI.fit_observation_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.fit_scitype
, LearnAPI.fit_type
, LearnAPI.fit_observation_scitype
, LearnAPI.fit_observation_type
.
LearnAPI.fit_type — Function
LearnAPI.fit_type(algorithm)
Return an upper bound on the type of data
guaranteed to work when calling fit(algorithm, data...)
.
Specifically, if the return value is T
and typeof(data) <: T
, then all the following calls are guaranteed to work:
fit(algorithm, data...)
obsdata = obs(fit, algorithm, data...)
fit(algorithm, Obs(), obsdata)
See also LearnAPI.fit_scitype, LearnAPI.fit_observation_type, LearnAPI.fit_observation_scitype.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.fit_scitype
, LearnAPI.fit_type
, LearnAPI.fit_observation_scitype
, LearnAPI.fit_observation_type
.
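The guarantee for LearnAPI.fit_type amounts to a subtype check on typeof(data), where data is the tuple of arguments following algorithm. The sketch below uses the example bound from the trait summary table and plain Julia only; the variable names are illustrative.

```julia
# Example upper bound, as in the trait summary table
T = Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Real}}

X = rand(10, 3)          # Matrix{Float64}
y = rand(10)             # Vector{Float64}
labels = fill("a", 10)   # Vector{String}

ok = typeof((X, y)) <: T             # true: fit(algorithm, X, y) is covered
not_ok = typeof((X, labels)) <: T    # false: no guarantee for string targets

# With the fallback value Union{}, the check fails for any data,
# so nothing is promised.
nothing_promised = typeof((X, y)) <: Union{}
```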
LearnAPI.fit_observation_scitype — Function
LearnAPI.fit_observation_scitype(algorithm)
Return an upper bound on the scitype of observations guaranteed to work when calling fit(algorithm, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by S
, supposing S != Union{}
, and that user supplies data
satisfying
ScientificTypes.scitype(MLUtils.getobs(data, i)) <: S
for any valid index i
, then all the following are guaranteed to work:
fit(algorithm, data...)
obsdata = obs(fit, algorithm, data...)
fit(algorithm, Obs(), obsdata)
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.fit_scitype
, LearnAPI.fit_type
, LearnAPI.fit_observation_scitype
, LearnAPI.fit_observation_type
.
LearnAPI.fit_observation_type — Function
LearnAPI.fit_observation_type(algorithm)
Return an upper bound on the type of observations guaranteed to work when calling fit(algorithm, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by T
, supposing T != Union{}
, and that user supplies data
satisfying
typeof(MLUtils.getobs(data, i)) <: T
for any valid index i
, then the following is guaranteed to work:
fit(algorithm, data...)
obsdata = obs(fit, algorithm, data...)
fit(algorithm, Obs(), obsdata)
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_scitype
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.fit_scitype
, LearnAPI.fit_type
, LearnAPI.fit_observation_scitype
, LearnAPI.fit_observation_type
.
LearnAPI.predict_input_scitype — Function
LearnAPI.predict_input_scitype(algorithm)
Return an upper bound on the scitype of data
guaranteed to work in the call predict(model, kind_of_proxy, data...)
.
Specifically, if S
is the value returned and ScientificTypes.scitype(data) <: S
, then the following is guaranteed to work:
predict(model, kind_of_proxy, data...)
obsdata = obs(predict, algorithm, data...)
predict(model, kind_of_proxy, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.predict_input_type
.
New implementations
Implementation is optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.predict_input_scitype
, LearnAPI.predict_input_type
, LearnAPI.predict_input_observation_scitype
, LearnAPI.predict_input_observation_type
.
LearnAPI.predict_input_observation_scitype — Function
LearnAPI.predict_input_observation_scitype(algorithm)
Return an upper bound on the scitype of observations guaranteed to work when calling predict(model, kind_of_proxy, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by S
, supposing S != Union{}
, and that user supplies data
satisfying
ScientificTypes.scitype(MLUtils.getobs(data, i)) <: S
for any valid index i
, then all the following are guaranteed to work:
predict(model, kind_of_proxy, data...)
obsdata = obs(predict, algorithm, data...)
predict(model, kind_of_proxy, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.predict_input_scitype
, LearnAPI.predict_input_type
, LearnAPI.predict_input_observation_scitype
, LearnAPI.predict_input_observation_type
.
LearnAPI.predict_input_type — Function
LearnAPI.predict_input_type(algorithm)
Return an upper bound on the type of data
guaranteed to work in the call predict(model, kind_of_proxy, data...)
.
Specifically, if T
is the value returned and typeof(data) <: T
, then the following is guaranteed to work:
predict(model, kind_of_proxy, data...)
obsdata = obs(predict, algorithm, data...)
predict(model, kind_of_proxy, Obs(), obsdata)
See also LearnAPI.predict_input_scitype
.
New implementations
Implementation is optional. The fallback return value is Union{}
. Should not be overloaded if LearnAPI.predict_input_scitype
is overloaded.
LearnAPI.predict_input_observation_type — Function
LearnAPI.predict_input_observation_type(algorithm)
Return an upper bound on the type of observations guaranteed to work when calling predict(model, kind_of_proxy, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by T
, supposing T != Union{}
, and that user supplies data
satisfying
typeof(MLUtils.getobs(data, i)) <: T
for any valid index i
, then all the following are guaranteed to work:
predict(model, kind_of_proxy, data...)
obsdata = obs(predict, algorithm, data...)
predict(model, kind_of_proxy, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.predict_input_scitype
, LearnAPI.predict_input_type
, LearnAPI.predict_input_observation_scitype
, LearnAPI.predict_input_observation_type
.
LearnAPI.predict_output_scitype — Function
LearnAPI.predict_output_scitype(algorithm, kind_of_proxy::KindOfProxy)
Return an upper bound for the scitypes of predictions of the specified form where supported, and otherwise return Any
. For example, if
ŷ = LearnAPI.predict(model, LearnAPI.Distribution(), data...)
successfully returns (i.e., algorithm
supports predictions of target probability distributions) then the following is guaranteed to hold:
scitype(ŷ) <: LearnAPI.predict_output_scitype(algorithm, LearnAPI.Distribution())
Note. This trait has a single-argument "convenience" version LearnAPI.predict_output_scitype(algorithm)
derived from this one, which returns a dictionary keyed on target proxy types.
See also LearnAPI.KindOfProxy
, LearnAPI.predict
, LearnAPI.predict_input_scitype
.
New implementations
Overloading the trait is optional. Here's a sample implementation for a supervised regressor type MyRgs
that only predicts actual values of the target:
LearnAPI.predict_output_scitype(::MyRgs, ::LearnAPI.LiteralTarget) = AbstractVector{ScientificTypesBase.Continuous}
The fallback method returns Any
.
LearnAPI.predict_output_scitype(algorithm)
Return a dictionary of upper bounds on the scitype of predictions, keyed on concrete subtypes of LearnAPI.KindOfProxy
. Each of these subtypes represents a different form of target prediction (LiteralTarget
, Distribution
, SurvivalFunction
, etc) possibly supported by algorithm
, but the existence of a key does not guarantee that form is supported.
As an example, if
ŷ = LearnAPI.predict(model, LearnAPI.Distribution(), data...)
successfully returns (i.e., algorithm
supports predictions of target probability distributions) then the following is guaranteed to hold:
scitype(ŷ) <: LearnAPI.predict_output_scitype(algorithm)[LearnAPI.Distribution]
See also LearnAPI.KindOfProxy
, LearnAPI.predict
, LearnAPI.predict_input_scitype
.
New implementations
This single-argument trait should not be overloaded. Instead, overload LearnAPI.predict_output_scitype(algorithm, kind_of_proxy).
LearnAPI.predict_output_type — Function
LearnAPI.predict_output_type(algorithm, kind_of_proxy::KindOfProxy)
Return an upper bound for the types of predictions of the specified form where supported, and otherwise return Any
. For example, if
ŷ = LearnAPI.predict(model, LearnAPI.Distribution(), data...)
successfully returns (i.e., algorithm
supports predictions of target probability distributions) then the following is guaranteed to hold:
typeof(ŷ) <: LearnAPI.predict_output_type(algorithm, LearnAPI.Distribution())
Note. This trait has a single-argument "convenience" version LearnAPI.predict_output_type(algorithm)
derived from this one, which returns a dictionary keyed on target proxy types.
See also LearnAPI.KindOfProxy
, LearnAPI.predict
, LearnAPI.predict_input_type
.
New implementations
Overloading the trait is optional. Here's a sample implementation for a supervised regressor type MyRgs
that only predicts actual values of the target:
LearnAPI.predict_output_type(::MyRgs, ::LearnAPI.LiteralTarget) = AbstractVector{<:Real}
The fallback method returns Any
.
LearnAPI.predict_output_type(algorithm)
Return a dictionary of upper bounds on the type of predictions, keyed on concrete subtypes of LearnAPI.KindOfProxy
. Each of these subtypes represents a different form of target prediction (LiteralTarget
, Distribution
, SurvivalFunction
, etc) possibly supported by algorithm
, but the existence of a key does not guarantee that form is supported.
As an example, if
ŷ = LearnAPI.predict(model, LearnAPI.Distribution(), data...)
successfully returns (i.e., algorithm
supports predictions of target probability distributions) then the following is guaranteed to hold:
typeof(ŷ) <: LearnAPI.predict_output_type(algorithm)[LearnAPI.Distribution]
See also LearnAPI.KindOfProxy
, LearnAPI.predict
, LearnAPI.predict_input_type
.
New implementations
This single-argument trait should not be overloaded. Instead, overload LearnAPI.predict_output_type(algorithm, kind_of_proxy).
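To make the relationship between the two versions concrete, here is a minimal sketch of how a single-argument dictionary could be derived from a two-argument trait. Everything is defined locally: `KindOfProxy`, `LiteralTarget`, `Distribution`, `MyRgs` and `KINDS` are hypothetical stand-ins, and the actual derivation in LearnAPI.jl may differ.

```julia
# Hypothetical stand-ins for LearnAPI.KindOfProxy and two concrete subtypes
abstract type KindOfProxy end
struct LiteralTarget <: KindOfProxy end
struct Distribution <: KindOfProxy end
const KINDS = (LiteralTarget, Distribution)

# Hypothetical regressor that only predicts actual target values
struct MyRgs end

# Two-argument trait: fallback, plus the one overload an implementer writes
predict_output_type(::Any, ::KindOfProxy) = Any
predict_output_type(::MyRgs, ::LiteralTarget) = AbstractVector{<:Real}

# Single-argument convenience version, derived rather than overloaded:
# a dictionary keyed on the concrete proxy types
predict_output_type(algorithm) =
    Dict(K => predict_output_type(algorithm, K()) for K in KINDS)

bounds = predict_output_type(MyRgs())
```

In this sketch the key Distribution is present even though MyRgs does not support it; its value is simply the uninformative fallback Any, which matches the remark that the existence of a key does not guarantee support.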
LearnAPI.transform_input_scitype — Function
LearnAPI.transform_input_scitype(algorithm)
Return an upper bound on the scitype of data
guaranteed to work in the call transform(model, data...)
.
Specifically, if S
is the value returned and ScientificTypes.scitype(data) <: S
, then the following is guaranteed to work:
transform(model, data...)
obsdata = obs(transform, algorithm, data...)
transform(model, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.transform_input_type
.
New implementations
Implementation is optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.transform_input_scitype
, LearnAPI.transform_input_type
, LearnAPI.transform_input_observation_scitype
, LearnAPI.transform_input_observation_type
.
LearnAPI.transform_input_observation_scitype — Function
LearnAPI.transform_input_observation_scitype(algorithm)
Return an upper bound on the scitype of observations guaranteed to work when calling transform(model, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by S
, supposing S != Union{}
, and that user supplies data
satisfying
ScientificTypes.scitype(MLUtils.getobs(data, i)) <: S
for any valid index i
, then all the following are guaranteed to work:
transform(model, data...)
obsdata = obs(transform, algorithm, data...)
transform(model, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.transform_input_scitype
, LearnAPI.transform_input_type
, LearnAPI.transform_input_observation_scitype
, LearnAPI.transform_input_observation_type
.
LearnAPI.transform_input_type — Function
LearnAPI.transform_input_type(algorithm)
Return an upper bound on the type of data
guaranteed to work in the call transform(model, data...)
.
Specifically, if T
is the value returned and typeof(data) <: T
, then the following is guaranteed to work:
transform(model, data...)
obsdata = obs(transform, algorithm, data...)
transform(model, Obs(), obsdata)
See also LearnAPI.transform_input_scitype
.
New implementations
Implementation is optional. The fallback return value is Union{}
. Should not be overloaded if LearnAPI.transform_input_scitype
is overloaded.
LearnAPI.transform_input_observation_type — Function
LearnAPI.transform_input_observation_type(algorithm)
Return an upper bound on the type of observations guaranteed to work when calling transform(model, data...)
, independent of the type/scitype of the data container itself. Here "observations" is in the sense of MLUtils.jl. Assuming this trait has value different from Union{}
the understanding is that data
implements the MLUtils.jl getobs
/numobs
interface.
Specifically, denoting the type returned above by T
, supposing T != Union{}
, and that user supplies data
satisfying
typeof(MLUtils.getobs(data, i)) <: T
for any valid index i
, then all the following are guaranteed to work:
transform(model, data...)
obsdata = obs(transform, algorithm, data...)
transform(model, Obs(), obsdata)
whenever algorithm = LearnAPI.algorithm(model)
.
See also LearnAPI.fit_type
, LearnAPI.fit_scitype
, LearnAPI.fit_observation_type
.
New implementations
Optional. The fallback return value is Union{}
. Ordinarily, at most one of the following should be overloaded for a given algorithm: LearnAPI.transform_input_scitype
, LearnAPI.transform_input_type
, LearnAPI.transform_input_observation_scitype
, LearnAPI.transform_input_observation_type
.
LearnAPI.predict_or_transform_mutates — Function
LearnAPI.predict_or_transform_mutates(algorithm)
Returns true
if predict
or transform
possibly mutate their first argument, model
, when LearnAPI.algorithm(model) == algorithm
. If false
, no arguments are ever mutated.
New implementations
This trait, falling back to false
, may only be overloaded when fit
has no data arguments (algorithm
does not generalize to new data). See more at fit
.
LearnAPI.transform_output_scitype — Function
LearnAPI.transform_output_scitype(algorithm)
Return an upper bound on the scitype of the output of the transform
operation.
See also LearnAPI.transform_input_scitype
.
New implementations
Implementation is optional. The fallback return value is Any
.
LearnAPI.transform_output_type — Function
LearnAPI.transform_output_type(algorithm)
Return an upper bound on the type of the output of the transform
operation.
New implementations
Implementation is optional. The fallback return value is Any
.