Learner Traits

Learner traits are simply functions whose sole argument is a learner.

Traits promise specific learner behavior, such as "This learner can make point or probabilistic predictions" or "This learner is supervised (sees a target in training)". They may also record more mundane information, such as a package license.

Trait summary

Overloadable traits

In the examples column of the table below, Continuous is a name owned by the package ScientificTypesBase.jl.

| trait | return value | fallback value | example |
|---|---|---|---|
| LearnAPI.constructor(learner) | constructor for generating new or modified versions of learner | (no fallback) | RidgeRegressor |
| LearnAPI.functions(learner) | functions you can apply to learner or associated model (traits excluded) | () | (:fit, :predict, :(LearnAPI.strip), :(LearnAPI.learner), :obs) |
| LearnAPI.kinds_of_proxy(learner) | instances kind of KindOfProxy for which an implementation of LearnAPI.predict(learner, kind, ...) is guaranteed | () | (Distribution(), Interval()) |
| LearnAPI.tags(learner) | lists one or more suggestive learner tags from LearnAPI.tags() | () | ("regression", "probabilistic") |
| LearnAPI.is_pure_julia(learner) | true if implementation is 100% Julia code | false | true |
| LearnAPI.pkg_name(learner) | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | "unknown" | "DecisionTree" |
| LearnAPI.pkg_license(learner) | name of license of package providing core code | "unknown" | "MIT" |
| LearnAPI.doc_url(learner) | url providing documentation of the core code | "unknown" | "https://en.wikipedia.org/wiki/Decision_tree_learning" |
| LearnAPI.load_path(learner) | string locating name returned by LearnAPI.constructor(learner), beginning with a package name | "unknown" | "FastTrees.LearnAPI.DecisionTreeClassifier" |
| LearnAPI.is_composite(learner) | true if one or more properties of learner may be a learner | false | true |
| LearnAPI.human_name(learner) | human name for the learner; should be a noun | type name with spaces | "elastic net regressor" |
| LearnAPI.iteration_parameter(learner) | symbolic name of an iteration parameter | nothing | :epochs |
| LearnAPI.data_interface(learner) | interface implemented by objects returned by obs | LearnAPI.RandomAccess() (supports MLUtils.getobs/numobs) | LearnAPI.Iterable() (supports iterate) |
| LearnAPI.fit_observation_scitype(learner) | upper bound on scitype(observation) for observation in data ensuring fit(learner, data) works | Union{} | Tuple{AbstractVector{Continuous}, Continuous} |
| LearnAPI.target_observation_scitype(learner) | upper bound on the scitype of each observation of the target | Any | Continuous |
| LearnAPI.is_static(learner) | true if fit consumes no data | false | true |

Derived Traits

The following are provided for convenience but should not be overloaded by new learners:

| trait | return value | example |
|---|---|---|
| LearnAPI.name(learner) | learner type name as string | "PCA" |
| LearnAPI.is_learner(learner) | true if learner is LearnAPI.jl-compliant | true |
| LearnAPI.target(learner) | true if fit sees a target variable; see LearnAPI.target | false |
| LearnAPI.weights(learner) | true if fit supports per-observation weights; see LearnAPI.weights | false |

Implementation guide

A single-argument trait is declared following this pattern:

LearnAPI.is_pure_julia(learner::MyLearnerType) = true

A shorthand for single-argument traits is available:

@trait MyLearnerType is_pure_julia=true

Multiple traits can be declared like this:

@trait(
    MyLearnerType,
    is_pure_julia = true,
    pkg_name = "MyPackage",
)

The global trait contract

To ensure that trait metadata can be stored in an external learner registry, LearnAPI.jl requires:

  1. Finiteness: The value of a trait is the same for all learners with the same value of LearnAPI.constructor(learner). This typically means trait values do not depend on type parameters! If is_composite(learner) = true, this requirement is dropped.

  2. Low-level deserializability: It should be possible to evaluate the trait value when LearnAPI is the only imported module.

Because of 1, combining a lot of functionality into one learner (e.g. the learner can perform either classification or regression) can mean traits are necessarily less informative (as in LearnAPI.target_observation_scitype(learner) = Any).
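The finiteness requirement can be illustrated with a minimal sketch. Here MyRegressor is a hypothetical parametric learner, and pkg_name is a locally defined stand-in for LearnAPI.pkg_name, so that the snippet runs without LearnAPI.jl installed. Attaching the trait to the unparameterized type makes its value independent of the type parameter T:

```julia
# Local stand-in for a trait such as `LearnAPI.pkg_name` (an assumption made
# so this sketch is self-contained); the fallback mirrors the documented one.
pkg_name(learner) = "unknown"

# A hypothetical learner with a type parameter:
struct MyRegressor{T<:Real}
    lambda::T
end
MyRegressor(; lambda=0.1) = MyRegressor(lambda)

# Dispatch on `MyRegressor`, not on `MyRegressor{Float64}` or similar, so that
# every learner built by the same constructor reports the same trait value:
pkg_name(::MyRegressor) = "MyPackage"

float_learner = MyRegressor(lambda=0.1)        # MyRegressor{Float64}
rational_learner = MyRegressor(lambda=1//2)    # MyRegressor{Rational{Int64}}
same_value = pkg_name(float_learner) == pkg_name(rational_learner)
```

With a real implementation, the same effect would be obtained by declaring the trait with @trait on the learner type.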

Reference

LearnAPI.constructor - Function
LearnAPI.constructor(learner)

Return a keyword constructor that can be used to clone learner:

julia> learner.lambda
0.1
julia> C = LearnAPI.constructor(learner)
julia> learner2 = C(lambda=0.2)
julia> learner2.lambda
0.2

New implementations

All new implementations must overload this trait.

Attach public LearnAPI.jl-related documentation for learner to the constructor, not the learner struct.

It must be possible to recover learner from the constructor returned as follows:

properties = propertynames(learner)
named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
@assert learner == LearnAPI.constructor(learner)(; named_properties...)

which can be tested with @assert LearnAPI.clone(learner) == learner.

The keyword constructor provided by LearnAPI.constructor must provide default values for all properties, with the exception of those that can take other LearnAPI.jl learners as values. These can be provided with the default nothing, with the constructor throwing an error if the default value persists.
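The requirements above can be sketched as follows. Ridge and Bagged are hypothetical learners invented for illustration, and constructor is a local stand-in for LearnAPI.constructor, so that the snippet runs without LearnAPI.jl. The helper recovers (also hypothetical) wraps the recovery check stated earlier:

```julia
# A hypothetical learner with a default for its only property:
struct Ridge
    lambda::Float64
end
Ridge(; lambda=0.1) = Ridge(lambda)

# A hypothetical composite learner: `atom` is itself a learner, so its default
# is `nothing`, and the constructor throws if that default persists.
struct Bagged
    atom
    n::Int
end
function Bagged(; atom=nothing, n=10)
    atom === nothing && throw(ArgumentError("You must specify `atom=...`."))
    return Bagged(atom, n)
end

# Local stand-in for `LearnAPI.constructor` (assumption, for self-containment):
constructor(::Ridge) = Ridge
constructor(::Bagged) = Bagged

# The recovery check from the text, wrapped in a helper:
function recovers(learner)
    properties = propertynames(learner)
    named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
    return learner == constructor(learner)(; named_properties...)
end

ridge_ok = recovers(Ridge(lambda=0.2))
bagged_ok = recovers(Bagged(atom=Ridge(), n=3))
```

Both structs are immutable with immutable fields, so the default == compares them by value. A learner with mutable fields would typically need its own == method for the check to pass.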

LearnAPI.functions - Function
LearnAPI.functions(learner)

Return a tuple of expressions representing functions that can be meaningfully applied with learner, or an associated model (object returned by fit(learner, ...)), as the first argument. Learner traits (methods for which learner is the only argument) are excluded.

To return actual functions rather than symbols, use @functions learner.

The returned tuple may include expressions like :(DecisionTree.print_tree), which reference functions not owned by LearnAPI.jl.

The understanding is that learner is a LearnAPI-compliant object whenever the return value is non-empty.
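Because the return value is a tuple of quoted expressions, it can be queried with ordinary membership tests. The sketch below uses a hand-written tuple, funcs, as a hypothetical return value of LearnAPI.functions(learner) for some supervised learner. Quoted expressions are plain data, so LearnAPI.jl does not need to be loaded to run it:

```julia
# Hypothetical return value of `LearnAPI.functions(learner)`; written out by
# hand here rather than obtained from a real learner.
funcs = (
    :(LearnAPI.fit),
    :(LearnAPI.learner),
    :(LearnAPI.strip),
    :(LearnAPI.obs),
    :(LearnAPI.features),
    :(LearnAPI.target),
    :(LearnAPI.predict),
)

# Expressions are compared structurally, so `in` behaves as one would hope:
sees_target = :(LearnAPI.target) in funcs
supports_transform = :(LearnAPI.transform) in funcs
```

Here sees_target is true and supports_transform is false, since :(LearnAPI.transform) was not listed.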

Do LearnAPI.functions() to list all possible elements of the return value owned by LearnAPI.jl.

Extended help

New implementations

All new implementations must implement this trait. Here's a checklist for elements in the return value:

| expression | implementation compulsory? | include in returned tuple? |
|---|---|---|
| :(LearnAPI.fit) | yes | yes |
| :(LearnAPI.learner) | yes | yes |
| :(LearnAPI.strip) | no | yes |
| :(LearnAPI.obs) | no | yes |
| :(LearnAPI.features) | no | yes, unless fit consumes no data |
| :(LearnAPI.target) | no | only if implemented |
| :(LearnAPI.weights) | no | only if implemented |
| :(LearnAPI.update) | no | only if implemented |
| :(LearnAPI.update_observations) | no | only if implemented |
| :(LearnAPI.update_features) | no | only if implemented |
| :(LearnAPI.predict) | no | only if implemented |
| :(LearnAPI.transform) | no | only if implemented |
| :(LearnAPI.inverse_transform) | no | only if implemented |
| <accessor functions> | no | only if implemented |

Also include any implemented accessor functions, both those owned by LearnAPI.jl, and any learner-specific ones. The LearnAPI.jl accessor functions are: LearnAPI.extras, LearnAPI.learner, LearnAPI.coefficients, LearnAPI.intercept, LearnAPI.tree, LearnAPI.trees, LearnAPI.feature_names, LearnAPI.feature_importances, LearnAPI.training_labels, LearnAPI.training_losses, LearnAPI.training_predictions, LearnAPI.training_scores and LearnAPI.components (LearnAPI.strip is always included).

LearnAPI.kinds_of_proxy - Function
LearnAPI.kinds_of_proxy(learner)

Returns a tuple of all instances, kind, for which predict(learner, kind, data...) has a guaranteed implementation. Each such kind subtypes LearnAPI.KindOfProxy. Examples are Point() (for predicting actual target values) and Distribution() (for predicting probability mass/density functions).

The call predict(model, data) always returns predict(model, kind, data), where kind is the first element of the trait's return value.
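The dispatch rule just described can be sketched in a self-contained way. Every name below (KindOfProxy, Point, Distribution, kinds_of_proxy, learner, predict, and the CoinFlipper types) is a simplified local stand-in or a hypothetical example, defined here so the snippet runs without LearnAPI.jl. It is not the library's own code:

```julia
# Simplified local stand-ins for LearnAPI.jl's proxy types (assumptions):
abstract type KindOfProxy end
struct Point <: KindOfProxy end
struct Distribution <: KindOfProxy end

struct CoinFlipper end              # hypothetical learner
struct CoinFlipperModel             # hypothetical object returned by `fit`
    learner::CoinFlipper
    p::Float64                      # estimated probability of heads
end

# The trait lists the supported kinds; the first one is the default:
kinds_of_proxy(::CoinFlipper) = (Distribution(), Point())
learner(model::CoinFlipperModel) = model.learner

# One `predict` method per supported kind of proxy:
predict(model::CoinFlipperModel, ::Distribution, data) =
    [(heads = model.p, tails = 1 - model.p) for _ in data]
predict(model::CoinFlipperModel, ::Point, data) =
    [model.p >= 0.5 ? :heads : :tails for _ in data]

# Two-argument fallback: delegate to the first kind listed by the trait.
predict(model, data) =
    predict(model, first(kinds_of_proxy(learner(model))), data)

model = CoinFlipperModel(CoinFlipper(), 0.7)
default_prediction = predict(model, 1:2)
point_prediction = predict(model, Point(), 1:2)
```

In this sketch, reordering the tuple returned by kinds_of_proxy to (Point(), Distribution()) would switch the two-argument call to point predictions.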

See also LearnAPI.predict, LearnAPI.KindOfProxy.

Extended help

New implementations

Must be overloaded whenever predict is implemented.

Elements of the returned tuple must be instances of LearnAPI.KindOfProxy. List all possibilities by running LearnAPI.kinds_of_proxy().

Suppose, for example, we have the following implementation of a supervised learner returning only probabilistic predictions:

LearnAPI.predict(learner::MyNewLearnerType, ::LearnAPI.Distribution, Xnew) = ...

Then we can declare

@trait MyNewLearnerType kinds_of_proxy = (LearnAPI.Distribution(),)

LearnAPI.jl provides the fallback for predict(model, data).

For more on target variables and target proxies, refer to the LearnAPI documentation.

LearnAPI.tags - Function
LearnAPI.tags(learner)

Lists one or more suggestive learner tags. Do LearnAPI.tags() to list all possible tags.

Warning

The value of this trait guarantees no particular behavior. The trait is intended for informal classification purposes only.

New implementations

This trait should return a tuple of strings, as in ("classifier", "text analysis").

LearnAPI.is_pure_julia - Function
LearnAPI.is_pure_julia(learner)

Returns true if training learner requires evaluation of pure Julia code only.

New implementations

The fallback is false.

LearnAPI.pkg_name - Function
LearnAPI.pkg_name(learner)

Return the name of the package module which supplies the core training algorithm for learner. This is not necessarily the package providing the LearnAPI interface.

Returns "unknown" if the learner implementation has not overloaded the trait.

New implementations

Must return a string, as in "DecisionTree".

LearnAPI.pkg_license - Function
LearnAPI.pkg_license(learner)

Return the name of the software license, such as "MIT", applying to the package where the core algorithm for learner is implemented.

LearnAPI.doc_url - Function
LearnAPI.doc_url(learner)

Return a url where the core algorithm for learner is documented.

Returns "unknown" if the learner implementation has not overloaded the trait.

New implementations

Must return a string, such as "https://en.wikipedia.org/wiki/Decision_tree_learning".

LearnAPI.load_path - Function
LearnAPI.load_path(learner)

Return a string indicating where in code the definition of the learner's constructor can be found, beginning with the name of the package module defining it. By "constructor" we mean the return value of LearnAPI.constructor(learner).

Implementation

For example, a return value of "FastTrees.LearnAPI.DecisionTreeClassifier" means the following Julia code will not error:

import FastTrees
import LearnAPI
@assert FastTrees.LearnAPI.DecisionTreeClassifier == LearnAPI.constructor(learner)

Returns "unknown" if the learner implementation has not overloaded the trait.

LearnAPI.is_composite - Function
LearnAPI.is_composite(learner)

Returns true if one or more properties (fields) of learner may themselves be learners, and false otherwise.

See also LearnAPI.components.

New implementations

This trait should be overloaded if one or more properties (fields) of learner may take learner values. Fallback return value is false. The keyword constructor for such a learner need not prescribe defaults for learner-valued properties. Implementation of the accessor function LearnAPI.components is recommended.

The value of the trait must depend only on the type of learner.

LearnAPI.human_name - Function
LearnAPI.human_name(learner)

Return a human-readable string representation of typeof(learner). Primarily intended for auto-generation of documentation.

New implementations

Optional. A fallback takes the type name, inserts spaces and removes capitalization. For example, KNNRegressor becomes "knn regressor". Better would be to overload the trait to return "K-nearest neighbors regressor". Ideally, this is a "concrete" noun like "ridge regressor" rather than an "abstract" noun like "ridge regression".
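To make the fallback behavior concrete, here is a minimal sketch of a function with the described effect. The name spaced_lowercase and its regex-based approach are assumptions made for illustration. This is not LearnAPI.jl's actual implementation, which may treat edge cases differently:

```julia
# Hypothetical helper: insert spaces at word boundaries in a CamelCase type
# name, then remove capitalization.
function spaced_lowercase(type_name::AbstractString)
    # lower-case letter or digit followed by a capital:
    # "RidgeRegressor" becomes "Ridge Regressor"
    s = replace(type_name, r"([a-z0-9])([A-Z])" => s"\1 \2")
    # run of capitals followed by a capitalized word:
    # "KNNRegressor" becomes "KNN Regressor"
    s = replace(s, r"([A-Z]+)([A-Z][a-z])" => s"\1 \2")
    return lowercase(s)
end

# A placeholder type, used only to show how a type name can be extracted:
struct KNNRegressor end
name_from_instance = spaced_lowercase(string(nameof(typeof(KNNRegressor()))))
```

In this sketch, name_from_instance is "knn regressor", matching the example in the text. An explicit overload returning "K-nearest neighbors regressor" would still read better in generated documentation.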

LearnAPI.data_interface - Function
LearnAPI.data_interface(learner)

Return the data interface supported by learner for accessing individual observations in representations of input data returned by obs(learner, data) or obs(model, data), whenever learner == LearnAPI.learner(model). Here data is fit, predict, or transform-consumable data.

Possible return values are LearnAPI.RandomAccess, LearnAPI.FiniteIterable, and LearnAPI.Iterable.

See also obs.

New implementations

The fallback returns LearnAPI.RandomAccess, which applies to arrays, most tables, and tuples of these. See the doc-string for details.

LearnAPI.iteration_parameter - Function
LearnAPI.iteration_parameter(learner)

The name of the iteration parameter of learner, or nothing if the algorithm is not iterative.

New implementations

Implement if algorithm is iterative. Returns a symbol or nothing.

LearnAPI.fit_observation_scitype - Function
LearnAPI.fit_observation_scitype(learner)

Return an upper bound S on the scitype of individual observations guaranteed to work when calling fit: if observations = obs(learner, data) and ScientificTypes.scitype(collect(o)) <:S for each o in observations, then the call fit(learner, data) is supported.

Here, "for each o in observations" is understood in the sense of LearnAPI.data_interface(learner). For example, if LearnAPI.data_interface(learner) == LearnAPI.RandomAccess(), then this means "for o in MLUtils.eachobs(observations)".

See also LearnAPI.target_observation_scitype.

New implementations

Optional. The fallback return value is Union{}.

LearnAPI.target_observation_scitype - Function
LearnAPI.target_observation_scitype(learner)

Return an upper bound S on the scitype of each observation of an applicable target variable. Specifically:

  • If :(LearnAPI.target) in LearnAPI.functions(learner) (i.e., fit consumes target variables) then "target" means anything returned by LearnAPI.target(learner, data), where data is an admissible argument in the call fit(learner, data).

  • S will always be an upper bound on the scitype of (point) observations that could be conceivably extracted from the output of predict.

To illustrate the second case, suppose we have

model = fit(learner, data)
ŷ = predict(model, Sampleable(), data_new)

Then each individual sample generated by each "observation" of ŷ (a vector of sampleable objects, say) will be bounded in scitype by S.

See also LearnAPI.fit_observation_scitype.

New implementations

Optional. The fallback return value is Any.

LearnAPI.is_static - Function
LearnAPI.is_static(learner)

Returns true if fit is called with no data arguments, as in fit(learner). That is, learner does not generalize to new data, and data is only provided at the predict or transform step.

For example, some clustering algorithms are applied with this workflow, to assign labels to the observations in X:

model = fit(learner) # no training data
labels = predict(model, X) # may mutate `model`!

# extract some byproducts of the clustering algorithm (e.g., outliers):
LearnAPI.extras(model)

New implementations

This trait, falling back to false, may only be overloaded when fit has no data arguments. See more at fit.
