LaplaceClassifier

LaplaceClassifier

A model type for constructing a Laplace classifier, based on LaplaceRedux.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

Do model = LaplaceClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in LaplaceClassifier(model=...).

LaplaceClassifier implements the Laplace approximation for classification models, as presented in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems, Article No. 1537, pp. 20089–20103.
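For orientation, the approximation can be written in the notation commonly used for it. The symbols below are illustrative and are not defined elsewhere in this docstring. The weights of the trained network receive a Gaussian posterior centred at the MAP estimate, with covariance given by the inverse Hessian of the negative log-posterior:

```math
p(\theta \mid \mathcal{D}) \approx \mathcal{N}\!\left(\theta_{\mathrm{MAP}},\ \Sigma\right),
\qquad
\Sigma = \left(\nabla^2_{\theta}\,\mathcal{L}(\mathcal{D};\theta)\,\Big|_{\theta=\theta_{\mathrm{MAP}}}\right)^{-1},
\qquad
\mathcal{L}(\mathcal{D};\theta) = -\log p(\mathcal{D}\mid\theta) - \log p(\theta).
```

Here θ_MAP is the minimiser of this loss reached by training the network. The backend and hessian_structure hyperparameters control how the Hessian is approximated.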

Training data

In MLJ or MLJBase, given a dataset X, y and a Flux chain Flux_Chain compatible with the dataset, pass the chain to the model:

laplace_model = LaplaceClassifier(; model = Flux_Chain, kwargs...)

then bind an instance laplace_model to data with

mach = machine(laplace_model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
  • y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)

Train the machine using fit!(mach, rows=...).

Hyperparameters (format: name-type-default value-restrictions)

  • model::Union{Flux.Chain,Nothing} = nothing: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
  • flux_loss = Flux.Losses.logitcrossentropy: a Flux loss function.
  • optimiser = Adam(): a Flux optimiser.
  • epochs::Integer = 1000::(_ > 0): the number of training epochs.
  • batch_size::Integer = 32::(_ > 0): the batch size.
  • subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)): the subset of weights to use, either :all, :last_layer, or :subnetwork.
  • subnetwork_indices = nothing: the indices of the parameters that form the subnetwork; only used when subset_of_weights = :subnetwork.
  • hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)): the structure of the Hessian matrix, either :full or :diagonal.
  • backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)): the backend to use, either :GGN or :EmpiricalFisher.
  • observational_noise (alias σ)::Float64 = 1.0: the standard deviation of the observational noise.
  • prior_mean (alias μ₀)::Float64 = 0.0: the mean of the prior distribution.
  • prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing: the precision matrix of the prior distribution.
  • fit_prior_nsteps::Int = 100::(_ > 0): the number of steps used to fit the priors.
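  • link_approx::Symbol = :probit::(_ in (:probit, :plugin)): the link approximation used to compute the predicted class probabilities. :probit accounts for the predictive variance, while :plugin uses the MAP estimate only.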
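The difference between the two link_approx options can be illustrated with a short, self-contained Julia sketch. It assumes the standard probit (MacKay) approximation, extended to several classes by rescaling each logit before the softmax. The helper names softmax_vec, plugin_probs and probit_probs and the input numbers are hypothetical, and LaplaceRedux's internal implementation may differ in detail.

```julia
# Hypothetical helpers: a sketch of the probit (MacKay) approximation,
# not LaplaceRedux's internal code.

# Numerically stable softmax over a vector of logits.
function softmax_vec(z::AbstractVector{<:Real})
    m = maximum(z)
    e = exp.(z .- m)
    return e ./ sum(e)
end

# :plugin - ignore the predictive variance and push the MAP logits through softmax.
plugin_probs(fμ::AbstractVector{<:Real}) = softmax_vec(fμ)

# :probit - shrink each logit by κ = 1 / sqrt(1 + π/8 * σ²), where σ² is the
# (assumed) predictive variance of that logit under the Laplace posterior.
function probit_probs(fμ::AbstractVector{<:Real}, fvar::AbstractVector{<:Real})
    κ = 1 ./ sqrt.(1 .+ (π / 8) .* fvar)
    return softmax_vec(κ .* fμ)
end

fμ   = [2.0, 0.5, -1.0]   # hypothetical MAP logits for three classes
fvar = [4.0, 1.0, 0.5]    # hypothetical predictive variances of those logits

p_plugin = plugin_probs(fμ)
p_probit = probit_probs(fμ, fvar)

println("plugin: ", round.(p_plugin; digits = 3))
println("probit: ", round.(p_probit; digits = 3))
```

Because κ ≤ 1, a larger predictive variance pulls the logits toward zero, so the probit probabilities tend to be less confident than the plugin ones. With zero variance, the two options coincide.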

Operations

  • predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.
  • predict_mode(mach, Xnew): instead return the mode of each prediction above.

Fitted parameters

The fields of fitted_params(mach) are:

  • mean: The mean of the posterior distribution.
  • H: The Hessian of the posterior distribution.
  • P: The precision matrix of the posterior distribution.
  • cov_matrix: The covariance matrix of the posterior distribution.
  • n_data: The number of data points.
  • n_params: The number of parameters.
  • n_out: The number of outputs.
  • loss: The loss value of the posterior distribution.
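The relation between the fitted parameters H, P and cov_matrix can be sketched in plain Julia. The sketch assumes, as in the Laplace approximation of Daxberger et al., that the posterior precision is the Hessian plus the prior precision and that the posterior covariance is its inverse. The toy matrix H, the prior scale λ and the isotropic prior P₀ = λI are hypothetical values, not output from a fitted machine. The sketch also shows what hessian_structure = :diagonal would keep.

```julia
using LinearAlgebra

# Toy numbers for a model with 3 parameters (hypothetical, not from a fitted machine).
H = [2.0 0.5 0.0;
     0.5 1.0 0.2;
     0.0 0.2 3.0]         # approximate Hessian of the loss at the MAP estimate
λ  = 1.0                  # hypothetical prior precision scale
P₀ = λ * I                # isotropic prior precision (a UniformScaling)

# :full Hessian structure: keep every entry of H.
P_full   = H + P₀         # posterior precision
cov_full = inv(P_full)    # posterior covariance

# :diagonal Hessian structure: keep only the diagonal of H.
P_diag   = Diagonal(diag(H)) + P₀
cov_diag = inv(P_diag)

println("posterior variances (full):     ", round.(diag(cov_full); digits = 3))
println("posterior variances (diagonal): ", round.(diag(cov_diag); digits = 3))
```

The diagonal variant stores n numbers instead of n² and ignores correlations between parameters. This is the usual trade-off when choosing :diagonal for large networks.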

Report

The fields of report(mach) are:

  • loss_history: an array containing the total loss per epoch.

Accessor functions

  • training_losses(mach): return the loss history from report

Examples

using MLJ
LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

X, y = @load_iris

## Define the Flux Chain model
using Flux
flux_model = Chain(
    Dense(4 => 10, relu),
    Dense(10 => 10, relu),
    Dense(10 => 3)
)

# Define the LaplaceClassifier, passing the Flux chain as its model
model = LaplaceClassifier(model = flux_model)

mach = machine(model, X, y) |> fit!

Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) ## probabilistic predictions
predict_mode(mach, Xnew)   ## point predictions
training_losses(mach)      ## loss history per epoch
pdf.(yhat, "virginica")    ## probabilities for the "virginica" class
fitted_params(mach)        ## NamedTuple with the fitted params of Laplace

See also LaplaceRedux.jl.