A base Julia interface for machine learning and statistics
LearnAPI.jl is a lightweight, functional-style interface providing a collection of methods, such as fit and predict, to be implemented by algorithms from machine learning and statistics. Through such implementations, these algorithms buy into functionality provided by ML/statistics toolboxes and other packages, such as hyperparameter optimization. LearnAPI.jl also provides a number of Julia traits through which an implementation can promise specific behavior.
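A trait, in the Julia sense used here, is a function of the algorithm whose value a toolbox can query before any training happens. The following is a minimal sketch of that general pattern under stated assumptions: the trait supports_weights and the types ToyForest and ToyKMeans are hypothetical and are not part of LearnAPI.jl. For the traits LearnAPI.jl actually defines, see the Reference.

```julia
# Hypothetical sketch of the trait pattern; none of these names are LearnAPI.jl's.

# Two toy algorithm types, each holding only hyperparameters.
struct ToyForest
    n_trees::Int
end

struct ToyKMeans
    k::Int
end

# Fallback: an algorithm makes no promise unless its implementation says so.
supports_weights(alg) = false

# An implementation opts in by adding a method for its own type.
supports_weights(::ToyForest) = true

# A toolbox queries the trait instead of inspecting the concrete type.
function describe(alg)
    if supports_weights(alg)
        return "accepts per-observation weights"
    else
        return "ignores per-observation weights"
    end
end

println(describe(ToyForest(100)))  # accepts per-observation weights
println(describe(ToyKMeans(3)))    # ignores per-observation weights
```

Because the fallback method returns false, an implementation that never mentions the trait still gives a well-defined (conservative) answer.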
The API described here is under active development and not ready for adoption. Join an ongoing design discussion at this Julia Discourse thread.
Sample workflow
Suppose forest is some object encapsulating the hyperparameters of the random forest algorithm (the number of trees, etc.). Then a LearnAPI.jl interface can be implemented, for objects with the type of forest, to enable the following basic workflow:
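using LearnAPI, Serialization  # Serialization (a standard library) provides serialize and deserialize
X = <some training features>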
y = <some training target>
Xnew = <some test or production features>
# Train:
model = fit(forest, X, y)
# Predict probability distributions:
predict(model, Distribution(), Xnew)
# Generate point predictions:
ŷ = predict(model, LiteralTarget(), Xnew) # or `predict(model, Xnew)`
# Apply an "accessor function" to inspect byproducts of training:
LearnAPI.feature_importances(model)
# Slim down and otherwise prepare model for serialization:
small_model = minimize(model)
serialize("my_random_forest.jls", small_model)
# Recover saved model and algorithm configuration:
recovered_model = deserialize("my_random_forest.jls")
@assert LearnAPI.algorithm(recovered_model) == forest
@assert predict(recovered_model, LiteralTarget(), Xnew) == ŷ
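To make the shape of this workflow concrete without a real random forest, here is a self-contained sketch using hypothetical stand-ins: a toy ridge regressor Ridge plays the role of forest, and fit, predict, LiteralTarget, coefficients, algorithm and minimize are local definitions written for illustration, not LearnAPI.jl's own (whose signatures may differ). Serialization goes through an in-memory IOBuffer instead of a file.

```julia
using LinearAlgebra  # for the uniform scaling `I`
using Serialization  # standard library: serialize, deserialize

# Hypothetical stand-ins for the workflow above; not LearnAPI.jl's definitions.

# The "algorithm": just hyperparameters (here, a ridge penalty).
struct Ridge
    lambda::Float64
end

# Singleton target proxy, used only to select a `predict` method.
struct LiteralTarget end

# The "model": the algorithm, the learned coefficients, and a training
# byproduct (in-sample fitted values) that prediction does not need.
struct RidgeModel
    algorithm::Ridge
    coefficients::Vector{Float64}
    fitted::Union{Nothing,Vector{Float64}}
end

function fit(alg::Ridge, X::AbstractMatrix, y::AbstractVector)
    β = (X' * X + alg.lambda * I) \ (X' * y)  # ridge normal equations
    return RidgeModel(alg, β, X * β)
end

predict(model::RidgeModel, ::LiteralTarget, Xnew::AbstractMatrix) =
    Xnew * model.coefficients
predict(model::RidgeModel, Xnew::AbstractMatrix) =
    predict(model, LiteralTarget(), Xnew)

# Accessor functions for inspecting the trained model.
coefficients(model::RidgeModel) = model.coefficients
algorithm(model::RidgeModel) = model.algorithm

# Drop byproducts that prediction does not use, before serializing.
minimize(model::RidgeModel) = RidgeModel(model.algorithm, model.coefficients, nothing)

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]   # 3 observations (rows), 2 features
y = [1.0, 2.0, 3.0]
Xnew = [2.0 1.0]                   # 1 new observation

ridge = Ridge(0.1)
model = fit(ridge, X, y)
ŷ = predict(model, LiteralTarget(), Xnew)  # same as predict(model, Xnew)

# Round trip through an in-memory buffer instead of a file.
io = IOBuffer()
serialize(io, minimize(model))
seekstart(io)
recovered_model = deserialize(io)

@assert algorithm(recovered_model) == ridge
@assert predict(recovered_model, LiteralTarget(), Xnew) == ŷ
```

Keeping the algorithm inside the model is what lets algorithm(recovered_model) return the original hyperparameters after a round trip, while minimize discards only what prediction never reads.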
Distribution and LiteralTarget are singleton types owned by LearnAPI.jl. They allow dispatch based on the kind of target proxy, a key LearnAPI.jl concept. LearnAPI.jl places more emphasis on the notion of target variables and target proxies than on the usual supervised/unsupervised learning dichotomy. From this point of view, a supervised algorithm is simply one in which a target variable exists, and happens to appear as an input to training but not to prediction.
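A small sketch of how singleton types drive dispatch, and of the target appearing in training but not in prediction. The types below are local stand-ins, not LearnAPI.jl's own definitions; the toy ConstantModel and fit_constant are hypothetical, and a "distribution" is represented as a plain (μ, σ) named tuple rather than an object from a probability package.

```julia
using Statistics  # mean, std

# Hypothetical stand-ins for the two target proxies (not LearnAPI.jl's own types).
struct LiteralTarget end
struct Distribution end

# A toy model that ignores the features and summarizes the training target.
struct ConstantModel
    μ::Float64
    σ::Float64
end

# The target y is an input to training ...
fit_constant(X::AbstractMatrix, y::AbstractVector) = ConstantModel(mean(y), std(y))

# ... but not to prediction. The proxy's type alone selects the method.
# Point predictions:
predict(m::ConstantModel, ::LiteralTarget, Xnew::AbstractMatrix) =
    fill(m.μ, size(Xnew, 1))
# "Distributions", represented here as (mean, standard deviation) named tuples:
predict(m::ConstantModel, ::Distribution, Xnew::AbstractMatrix) =
    fill((μ = m.μ, σ = m.σ), size(Xnew, 1))

X = zeros(4, 2)             # 4 training observations (rows)
y = [1.0, 2.0, 3.0, 4.0]
Xnew = zeros(2, 2)          # 2 new observations

m = fit_constant(X, y)
point = predict(m, LiteralTarget(), Xnew)   # [2.5, 2.5]
dists = predict(m, Distribution(), Xnew)    # two (μ = 2.5, σ = ...) tuples
```

Since a singleton type has exactly one instance and carries no data, passing LiteralTarget() or Distribution() costs nothing at run time; it only tells the compiler which method to call.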
In LearnAPI.jl, a method called obs gives users access to an "internal", algorithm-specific representation of input data, which is always "observation-accessible", in the sense that it can be resampled using the MLUtils.jl getobs/numobs interface. The implementation can arrange for this resampling to be efficient, and workflows based on obs can have performance benefits, for example by converting the data to the internal representation once rather than on every resampling.
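A rough sketch of the idea, under several assumptions: the user supplies features as a vector of named-tuple rows; the internal representation is a matrix with one column per observation; and the obs signature, ToyAlgorithm and InternalData are hypothetical. The local numobs and getobs stand in for MLUtils.numobs and MLUtils.getobs, which a real implementation would extend instead of redefining.

```julia
# Hypothetical sketch of the obs idea. The local numobs/getobs stand in for
# MLUtils.numobs and MLUtils.getobs, which a real implementation would extend.

struct ToyAlgorithm end

# Algorithm-specific internal representation: features stored as a p × n
# matrix (one column per observation), plus the target vector.
struct InternalData
    A::Matrix{Float64}
    y::Vector{Float64}
end

# Convert user-facing data (a vector of named-tuple rows) once, up front.
function obs(::ToyAlgorithm, rows::AbstractVector{<:NamedTuple}, y::AbstractVector)
    p = length(first(rows))
    A = [Float64(r[j]) for j in 1:p, r in rows]  # p × n
    return InternalData(A, Float64.(y))
end

# Observation access on the internal representation is plain column indexing.
numobs(d::InternalData) = length(d.y)
getobs(d::InternalData, idx::AbstractVector{<:Integer}) =
    InternalData(d.A[:, idx], d.y[idx])

rows = [(a = 1.0, b = 2.0), (a = 3.0, b = 4.0), (a = 5.0, b = 6.0), (a = 7.0, b = 8.0)]
y = [10.0, 20.0, 30.0, 40.0]

data = obs(ToyAlgorithm(), rows, y)   # the only conversion from rows

# Two-fold resampling: each fold is sliced from `data`, never rebuilt from `rows`.
n = numobs(data)
for test in (1:2:n, 2:2:n)
    train = setdiff(1:n, test)
    train_data = getobs(data, train)
    test_data = getobs(data, test)
    println(numobs(train_data), " train / ", numobs(test_data), " test")  # 2 train / 2 test
end
```

In a cross-validation loop with many folds, the saving comes from the row-to-matrix conversion happening once, outside the loop.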
Learning more
Anatomy of an Implementation: informal introduction to the main actors in a new LearnAPI.jl implementation
Reference: official specification
Common Implementation Patterns: implementation suggestions for common, informally defined, algorithm types