LearnAPI.jl
A base Julia interface for machine learning and statistics

LearnAPI.jl is a lightweight, functional-style interface, providing a collection of methods, such as fit and predict, to be implemented by algorithms from machine learning and statistics, some examples of which are listed here. A careful design ensures algorithms implementing LearnAPI.jl can buy into functionality, such as external performance estimates, hyperparameter optimization and model composition, provided by ML/statistics toolboxes and other packages. LearnAPI.jl includes a number of Julia traits for promising specific behavior.

Sample workflow

Suppose forest is some object encapsulating the hyperparameters of the random forest algorithm (the number of trees, etc.). Then, a LearnAPI.jl interface can be implemented, for objects with the type of forest, to enable the basic workflow below. In this case data is presented following the "scikit-learn" X, y pattern, although LearnAPI.jl supports other data patterns.

# `X` is some training features
# `y` is some training target
# `Xnew` is some test or production features

# List LearnAPI functions implemented for `forest`:
@functions forest

# Train:
model = fit(forest, (X, y))

# Generate point predictions:
ŷ = predict(model, Xnew) # or `predict(model, Point(), Xnew)`

# Predict probability distributions:
predict(model, Distribution(), Xnew)

# Apply an "accessor function" to inspect byproducts of training:
LearnAPI.feature_importances(model)

# Slim down and otherwise prepare model for serialization:
using Serialization # standard library providing `serialize`
small_model = LearnAPI.strip(model)
serialize("my_random_forest.jls", small_model)

Distribution and Point are singleton types owned by LearnAPI.jl. They allow dispatch based on the kind of target proxy, a key LearnAPI.jl concept. LearnAPI.jl places more emphasis on the notion of target variables and target proxies than on the usual supervised/unsupervised learning dichotomy. From this point of view, a supervised learner is simply one in which a target variable exists, and happens to appear as an input to training but not to prediction.
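To make the dispatch idea concrete, here is a minimal standalone sketch that does not use LearnAPI.jl at all. The names ToyPoint, ToyDistribution, ToyModel and toy_predict are hypothetical stand-ins invented for this illustration; the real Point and Distribution types, and the real predict signature, are those owned by LearnAPI.jl.

```julia
# Standalone illustration of dispatch on singleton "kind of proxy" types.
# All names here are hypothetical; none of this is the LearnAPI.jl API.

abstract type ToyKindOfProxy end
struct ToyPoint <: ToyKindOfProxy end          # singleton type: no fields
struct ToyDistribution <: ToyKindOfProxy end   # singleton type: no fields

# A toy "trained model" that predicts the same thing for every observation:
struct ToyModel
    mean::Float64
    std::Float64
end

# One method per kind of target proxy, selected by the second argument:
toy_predict(model::ToyModel, ::ToyPoint, Xnew) =
    fill(model.mean, length(Xnew))
toy_predict(model::ToyModel, ::ToyDistribution, Xnew) =
    [(mean = model.mean, std = model.std) for _ in Xnew]

# Convenience method: default to point predictions when no proxy is given
# (an assumption of this sketch, mirroring the sample workflow above):
toy_predict(model::ToyModel, Xnew) = toy_predict(model, ToyPoint(), Xnew)

model = ToyModel(2.0, 0.5)
Xnew = [10.0, 20.0, 30.0]

point_predictions = toy_predict(model, Xnew)
distribution_predictions = toy_predict(model, ToyDistribution(), Xnew)
```

Because the proxy argument carries no data, it costs nothing at run time; it exists only to select a method, so one model type can offer several kinds of prediction under a single function name.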

Data interfaces and front ends

Algorithms are free to consume data in any format. However, this means meta-algorithms, such as cross-validation, need some means of subsampling observations without unnecessarily repeating the internal conversion of input data into the form needed by core algorithms. LearnAPI.jl's solution to this problem is a method called obs(learner, data) (read as "observations"), which exposes to the user, and hence to third-party meta-algorithms, a learner-specific, "internal" representation of the "external" data ordinarily supplied to fit (or predict) by the user. For example, data might be a table with mixed column types, while obs(learner, data) consists only of numerical arrays.

Unless the implementation opts out, such a representation is additionally guaranteed to implement a standard interface for accessing individual observations: the MLCore.jl getobs/numobs API (previously provided by MLUtils.jl), which is here tagged as LearnAPI.RandomAccess(). The representation can then be subsampled, as in cross-validation, without regard to its details. Moreover, such "observations" (subsampled or not) can be passed on to fit and predict in place of the original external form of data. In other words, obs factors the internal preprocessing of user-supplied data out of fit, but in a way that ensures the intercepted, internal form of data implements a standard subsampling API.

Two pathways to generating a model, with and without subsampling. Here obs is provided by a LearnAPI.jl learner implementation, while getobs is an MLCore.jl method for subsampling.

If the input consumed by the algorithm already implements the LearnAPI.RandomAccess() interface (tables, arrays, etc.) then overloading obs is completely optional, as LearnAPI.jl provides a no-operation fallback. Plain iteration interfaces, with or without knowledge of the number of observations, can also be specified, to support, e.g., data loaders reading images from disk.
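The two pathways described above can be sketched with a self-contained toy. Everything here is a hypothetical stand-in written for this illustration: ToyRidge, ToyObs, toy_obs, toy_numobs, toy_getobs and toy_fit are not part of LearnAPI.jl or MLCore.jl; they only imitate the roles played by a learner, obs, numobs, getobs and fit.

```julia
using LinearAlgebra  # for the identity `I`

# Hypothetical learner holding one hyperparameter:
struct ToyRidge
    lambda::Float64
end

# "Internal" representation of the data: numerical arrays only.
struct ToyObs
    A::Matrix{Float64}   # features, one observation per row
    y::Vector{Float64}   # target
end

# Stand-ins for the `numobs`/`getobs` subsampling interface
# (`idx` is a collection of row indices, such as `1:3`):
toy_numobs(o::ToyObs) = length(o.y)
toy_getobs(o::ToyObs, idx) = ToyObs(o.A[idx, :], o.y[idx])

# Stand-in for `obs`: convert "external" data, here a tuple `(X, y)` with
# `X` a named tuple of columns, into the internal form.
function toy_obs(::ToyRidge, data::Tuple)
    X, y = data
    A = hcat((Float64.(col) for col in values(X))...)
    return ToyObs(A, Float64.(y))
end
# Data already in internal form passes through unchanged:
toy_obs(::ToyRidge, o::ToyObs) = o

# Stand-in for `fit`: accepts external or internal data, and returns
# ridge regression coefficients as the "model".
function toy_fit(learner::ToyRidge, data)
    o = toy_obs(learner, data)
    return (o.A' * o.A + learner.lambda * I) \ (o.A' * o.y)
end

learner = ToyRidge(0.1)
X = (x1 = [1, 2, 3, 4], x2 = [0.5, 0.1, 0.9, 0.3])  # mixed column types
y = [1.0, 2.0, 3.0, 4.0]

# Pathway 1: fit directly on the external data.
coefs_direct = toy_fit(learner, (X, y))

# Pathway 2: convert once, subsample the internal form, then fit.
observations = toy_obs(learner, (X, y))
train = toy_getobs(observations, 1:3)
coefs_sub = toy_fit(learner, train)

# Fitting on the full internal form agrees with pathway 1:
coefs_full = toy_fit(learner, observations)
```

In this toy, the conversion from table to matrix happens once, in toy_obs, no matter how many subsamples a meta-algorithm later draws with toy_getobs. The pass-through method plays the role of the no-operation fallback mentioned above.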

In the typical case, a new implementation can avoid actually coding data preprocessing by using a canned data front end (a ready-made implementation of obs). Such front ends are provided by the LearnDataFrontEnds.jl package.

Learning more