A base Julia interface for machine learning and statistics

LearnAPI.jl is a lightweight, functional-style interface, providing a collection of methods, such as fit and predict, to be implemented by algorithms from machine learning and statistics. Through such implementations, these algorithms buy into functionality, such as hyperparameter optimization, as provided by ML/statistics toolboxes and other packages. LearnAPI.jl also provides a number of Julia traits for promising specific behavior.
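
To illustrate the functional style, here is a minimal self-contained sketch that does not load LearnAPI.jl. The functions fit, predict and is_pure_julia are defined locally as hypothetical stand-ins, and Ridge (ridge regression without an intercept) is a hypothetical algorithm. The algorithm object holds only hyperparameters, fit returns a new object instead of mutating anything, and the trait is modelled, as an assumption, as an ordinary function of the algorithm with a fallback value.

```julia
using LinearAlgebra  # provides the identity operator `I`

# Local stand-ins for generic functions; LearnAPI.jl is not loaded here.
function fit end
function predict end

# Hypothetical trait: a fallback value that implementations may overload.
is_pure_julia(algorithm) = false

# Hypothetical algorithm: an object holding hyperparameters only.
struct Ridge
    lambda::Float64
end

# The object returned by `fit`; it keeps the algorithm for later inspection.
struct RidgeModel
    algorithm::Ridge
    coefficients::Vector{Float64}
end

# Training builds a new object and leaves `algorithm` untouched.
function fit(algorithm::Ridge, X::AbstractMatrix, y::AbstractVector)
    coefficients = (X' * X + algorithm.lambda * I) \ (X' * y)
    return RidgeModel(algorithm, coefficients)
end

# Rows of `Xnew` are observations.
predict(model::RidgeModel, Xnew::AbstractMatrix) = Xnew * model.coefficients

# Overloading the trait promises a specific behavior for this algorithm.
is_pure_julia(::Ridge) = true

# Usage:
X = [1.0 0.0; 0.0 1.0; 1.0 1.0]   # 3 observations (rows), 2 features
y = [1.0, 2.0, 3.0]
model = fit(Ridge(0.0), X, y)
ŷ = predict(model, [2.0 3.0])      # one new observation
```

Keeping hyperparameters and learned parameters in separate objects means the same algorithm object can be reused for many fits, which is what makes external tooling such as hyperparameter optimization straightforward.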


The API described here is under active development and not ready for adoption. Join the ongoing design discussion on the Julia Discourse forum.

Sample workflow

Suppose forest is some object encapsulating the hyperparameters of the random forest algorithm (the number of trees, etc.). Then a LearnAPI.jl interface can be implemented for objects of the same type as forest, enabling the following basic workflow:

using LearnAPI
using Serialization  # standard library: provides serialize and deserialize

X = <some training features>
y = <some training target>
Xnew = <some test or production features>

# Train:
model = fit(forest, X, y)

# Predict probability distributions:
predict(model, Distribution(), Xnew)

# Generate point predictions:
ŷ = predict(model, LiteralTarget(), Xnew) # or `predict(model, Xnew)`

# Apply an "accessor function" to inspect byproducts of training:
LearnAPI.feature_importances(model)

# Slim down and otherwise prepare model for serialization:
small_model = minimize(model)
serialize("my_random_forest.jls", small_model)

# Recover saved model and algorithm configuration:
recovered_model = deserialize("my_random_forest.jls")
@assert LearnAPI.algorithm(recovered_model) == forest
@assert predict(recovered_model, LiteralTarget(), Xnew) == ŷ
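
The minimize and serialization steps in the workflow assume the model can shed anything prediction does not need. The following self-contained sketch, which does not load LearnAPI.jl, shows one way an implementation might arrange this. Forest, ForestModel, its fields, and the local minimize and predict are all hypothetical; a toy linear rule stands in for the trees. The slimmed model is round-tripped through Julia's Serialization standard library.

```julia
using Serialization  # standard library providing serialize / deserialize

# Hypothetical algorithm object: hyperparameters only.
struct Forest
    n_trees::Int
end

# Hypothetical trained model. `weights` (a toy linear rule standing in for
# the trees) is all that `predict` needs; `training_log` is a training
# byproduct that an accessor function might expose.
struct ForestModel
    algorithm::Forest
    weights::Vector{Float64}
    training_log::Union{Vector{String},Nothing}
end

predict(model::ForestModel, Xnew::AbstractMatrix) = Xnew * model.weights

# Local stand-in for `minimize`: keep what prediction needs, drop the rest.
minimize(model::ForestModel) =
    ForestModel(model.algorithm, model.weights, nothing)

forest = Forest(100)
model = ForestModel(forest, [0.5, -1.0], ["tree 1 grown", "tree 2 grown"])
Xnew = [1.0 2.0; 3.0 4.0]
ŷ = predict(model, Xnew)

small_model = minimize(model)
path = tempname() * ".jls"    # temporary file rather than a fixed name
serialize(path, small_model)
recovered_model = deserialize(path)
rm(path)

@assert recovered_model.algorithm == forest
@assert predict(recovered_model, Xnew) == ŷ
```

Because the algorithm is stored inside the model, the hyperparameter configuration survives serialization along with the learned parameters.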

Distribution and LiteralTarget are singleton types owned by LearnAPI.jl. They allow dispatch based on the kind of target proxy, a key LearnAPI.jl concept. LearnAPI.jl places more emphasis on the notion of target variables and target proxies than on the usual supervised/unsupervised learning dichotomy. From this point of view, a supervised algorithm is simply one in which a target variable exists, and happens to appear as an input to training but not to prediction.
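
As a rough illustration of dispatching on the kind of target proxy, the sketch below defines its own singleton types LiteralTarget and Distribution (local stand-ins, not the LearnAPI.jl types) and a hypothetical ConstantModel that ignores features and predicts from the training target alone. Since no distributions package is loaded, each predicted "distribution" is represented, as an assumption, by a named tuple of mean and standard deviation. The two-argument predict falls back to literal target predictions, mirroring the predict(model, Xnew) shorthand in the workflow.

```julia
# Local singleton types (stand-ins, not the LearnAPI.jl types): they carry
# no data and exist only so that `predict` can dispatch on them.
struct LiteralTarget end
struct Distribution end

# Hypothetical trained model summarising the training target by its mean
# and sample standard deviation.
struct ConstantModel
    mean::Float64
    std::Float64
end

# Hypothetical training routine for the toy model.
function fit_constant(y::AbstractVector{<:Real})
    m = sum(y) / length(y)
    s = sqrt(sum(abs2, y .- m) / (length(y) - 1))
    return ConstantModel(m, s)
end

# Point predictions: one number per observation (row) of `Xnew`.
predict(model::ConstantModel, ::LiteralTarget, Xnew::AbstractMatrix) =
    fill(model.mean, size(Xnew, 1))

# Probabilistic predictions, represented here as (μ, σ) named tuples.
predict(model::ConstantModel, ::Distribution, Xnew::AbstractMatrix) =
    fill((μ = model.mean, σ = model.std), size(Xnew, 1))

# Two-argument form: default to literal target predictions.
predict(model::ConstantModel, Xnew::AbstractMatrix) =
    predict(model, LiteralTarget(), Xnew)
```

Adding a new kind of target proxy in this design means adding one more singleton type and one more predict method; existing methods are untouched.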

In LearnAPI.jl, a method called obs gives users access to an "internal", algorithm-specific representation of input data, which is always "observation-accessible", in the sense that it can be resampled using the getobs/numobs interface of MLUtils.jl. An implementation can arrange for this resampling to be efficient, so workflows based on obs can have performance benefits.
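
The idea behind obs can be sketched without loading LearnAPI.jl or MLUtils.jl. In the sketch below, obs, numobs and getobs are local stand-ins (the last two borrow MLUtils.jl's names only), and MyAlgorithm and MyObs are hypothetical. The assumed internal representation stores one observation per column, converted once from row-per-observation user data, so each resample returns data already in the internal form.

```julia
# Hypothetical algorithm whose training loop reads one observation at a time.
struct MyAlgorithm end

# Hypothetical internal data representation produced by `obs`.
struct MyObs
    features::Matrix{Float64}   # p × n: one column per observation
    target::Vector{Float64}     # length n
end

# Local stand-in for `obs`: convert user data (n × p, one row per
# observation) into the internal form once, up front.
obs(::MyAlgorithm, X::AbstractMatrix, y::AbstractVector) =
    MyObs(permutedims(Float64.(X)), Float64.(y))

# Local stand-ins named after MLUtils.jl's `numobs` and `getobs`.
numobs(data::MyObs) = length(data.target)

# Resampling returns data in the same internal form; because Julia arrays
# are column-major, each selected observation is a contiguous column.
getobs(data::MyObs, idx::AbstractVector{<:Integer}) =
    MyObs(data.features[:, idx], data.target[idx])

# Usage:
X = [1.0 10.0; 2.0 20.0; 3.0 30.0]   # 3 observations, 2 features
y = [0.1, 0.2, 0.3]
data = obs(MyAlgorithm(), X, y)
batch = getobs(data, [1, 3])          # still a `MyObs`
```

A real implementation would instead add methods to MLUtils.jl's own numobs and getobs so that generic resampling tools work unchanged; the local functions here only keep the sketch self-contained. Whether the one-off conversion pays for itself depends on how often the data is resampled.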

Learning more