LearnDataFrontEnds.jl
Developer tool for adding canned data front ends to LearnAPI.jl implementations

Front ends

LearnDataFrontEnds (Module)
LearnDataFrontEnds

Module providing the following commonly applicable data front ends for implementations of the LearnAPI.jl interface:

  • Saffron: good for most supervised learners, typically regressors, operating on structured data

  • Sage: good for most supervised classifiers operating on structured data

  • Tarragon: good for most transformers

See Obs for the corresponding back end API (the interface for the output of obs).

Why add a front end from this package?

  • Users get to specify data in flexible ways: ordinary arrays or most tabular formats supported by Tables.jl. Targets or multitargets can be specified separately, or by column name(s). Standard data preprocessing, such as one-hot encoding and adding higher order feature interactions, can be specified by an R-style "formula", as provided by StatsModels.jl.

  • Developers can focus on core algorithm development, in which data conforms to a standard interface; see Obs.


Back end API

LearnDataFrontEnds.Obs (Type)
Obs

Abstract type for all "observations" returned by learners implementing a front end from LearnDataFrontEnds.jl - that is, for any object returned by LearnAPI.obs(learner, data) or LearnAPI.obs(model, data), where learner implements such a front end and model is an object returned by fit(learner, ...).

Any instance, observations, supports the following property access:

  • observations.features: size (p, n) feature matrix (n the number of observations)

  • observations.names: length p vector of feature names (as symbols)

Any instance, observations, also implements the LearnAPI.RandomAccess interface for accessing individual observations, for example for the purposes of resampling.
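To make the layout concrete, here is a minimal sketch in plain Julia using a hypothetical stand-in type, ToyObs. It is not the package's Obs implementation, and the helper names nobservations and subset are illustrative assumptions. It only mirrors the documented convention that observations are stored as the columns of a (p, n) matrix, so that selecting observations means selecting columns.

```julia
# Hypothetical stand-in for illustration only; not the type returned by `obs`.
struct ToyObs{M<:AbstractMatrix}
    features::M            # size (p, n): one column per observation
    names::Vector{Symbol}  # length p: one name per feature (row)
end

# The number of observations is the number of columns.
nobservations(o::ToyObs) = size(o.features, 2)

# Selecting observations means selecting columns; feature names are unchanged.
subset(o::ToyObs, idx) = ToyObs(o.features[:, idx], o.names)

o = ToyObs([1.0 2.0 3.0; 4.0 5.0 6.0], [:x1, :x2])   # p = 2, n = 3
train = subset(o, 1:2)

size(o.features)       # (2, 3)
nobservations(train)   # 2
train.features         # [1.0 2.0; 4.0 5.0]
```

A resampling routine would only ever need something like subset and nobservations, which is why column-wise storage pairs naturally with random access to observations.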

Specific to Saffron and Sage

Additionally, when observations = obs(learner, data) and the Saffron(multitarget=...) or Sage(multitarget=...) front end has been implemented, one has:

  • observations.target: length n target vector (multitarget=false) or size (q, n) target matrix (multitarget=true); in the Saffron case, this array has the same element type as the user-provided target
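The sketch below does by hand what the text above describes for a single target specified by column name. It is an illustration in plain Julia on a NamedTuple of column vectors; the variable names are assumptions, and it does not show how the Saffron or Sage front ends are actually implemented.

```julia
# A tiny column table: two features and one target, n = 3 observations.
table = (x1 = [1.0, 2.0, 3.0], x2 = [4.0, 5.0, 6.0], y = [0.1, 0.2, 0.3])
target_name = :y   # illustrative choice of target column

# Every column other than the target is treated as a feature.
feature_names = [k for k in keys(table) if k != target_name]

# hcat gives an (n, p) matrix; permutedims flips it to the (p, n) layout.
features = permutedims(hcat((table[k] for k in feature_names)...))
target = table[target_name]

size(features)   # (2, 3)
length(target)   # 3
```

With a front end in place, the core algorithm can assume it always receives data in this shape, regardless of how the user supplied it.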

Specific to Sage

If Sage(multitarget=..., code_type=...) has been implemented, then observations.target has an integer element type controlled by code_type, and we additionally have:

  • observations.levels_seen: A categorical vector of the ordered target levels actually seen in the user-supplied target. The corresponding integer codes are sort(unique(observations.target)). To get the full pool of levels, apply CategoricalArrays.levels to observations.levels_seen; see the example below.

  • observations.decoder: A callable object that converts an integer code back to the original CategoricalValue it represents.

Pass the first (observations.levels_seen) on to predict when making probabilistic predictions, and the second (observations.decoder) when making point predictions; see Sage for details.
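The relationship between the integer codes, the levels seen, and the decoder can be illustrated without any packages. The following is a toy sketch in plain Julia, assuming a hypothetical pool of three levels and 1-based codes. The real front end works with CategoricalArrays objects, and the names pool, codes, and decode are illustrative only.

```julia
pool = ['a', 'b', 'c']          # full pool of levels (hypothetical)
y = collect("ababababa")        # training target; 'c' never occurs

# Encode each value as its position in the pool, stored as UInt32.
codes = UInt32[findfirst(==(v), pool) for v in y]

# Codes actually observed, and the corresponding levels seen.
seen_codes = sort(unique(codes))
levels_seen = pool[seen_codes]

# Decoder: map an integer code back to the value it represents.
decode(code) = pool[code]

seen_codes              # UInt32[0x00000001, 0x00000002]
levels_seen             # ['a', 'b']
decode.(codes) == y     # true
```

Note that the levels seen ('a' and 'b') form a subset of the pool, which is the distinction the levels_seen property draws.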

Extended help

In the example below, observations implements the full Obs interface described above, for a learner implementing the Sage front end:

using LearnAPI, LearnDataFrontEnds, LearnTestAPI
using CategoricalDistributions, CategoricalArrays, DataFrames
X = DataFrame(rand(10, 3), :auto)
y = categorical(collect("ababababac"))
learner = LearnTestAPI.ConstantClassifier()
observations = obs(learner, (X[1:9,:], y[1:9]))

julia> observations.features
3×9 Matrix{Float64}:
 0.234043  0.526468  0.227417  0.956471    …  0.00587146  0.169291  0.353518  0.402631
 0.631083  0.151317  0.781049  0.00320728     0.756519    0.15317   0.452169  0.127005
 0.285315  0.347433  0.69174   0.516915       0.900343    0.404006  0.448986  0.962649

julia> yint = observations.target
9-element Vector{UInt32}:
 0x00000001
 0x00000002
 0x00000001
 0x00000002
 0x00000001
 0x00000002
 0x00000001
 0x00000002
 0x00000001

julia> observations.levels_seen
2-element CategoricalArray{Char,1,UInt32}:
 'a'
 'b'

julia> sort(unique(observations.target))
2-element Vector{UInt32}:
 0x00000001
 0x00000002

julia> observations.levels_seen |> levels
3-element CategoricalArray{Char,1,UInt32}:
 'a'
 'b'
 'c'

julia> observations.decoder.(yint)
9-element CategoricalArray{Char,1,UInt32}:
 'a'
 'b'
 'a'
 'b'
 'a'
 'b'
 'a'
 'b'
 'a'

julia> d = UnivariateFinite(observations.levels_seen, [0.4, 0.6])
UnivariateFinite{Multiclass{3}}(a=>0.4, b=>0.6)

julia> levels(d)
3-element CategoricalArray{Char,1,UInt32}:
 'a'
 'b'
 'c'