# LaplaceClassifier
A model type for constructing a Laplace classifier, based on LaplaceRedux.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

```julia
LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux
```

Do `model = LaplaceClassifier()` to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in `LaplaceClassifier(model=...)`.
`LaplaceClassifier` implements Laplace Redux (effortless Bayesian deep learning) for classification models, originally published in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103.
# Training data

In MLJ or MLJBase, given a dataset `X`, `y` and a Flux chain `Flux_Chain` adapted to the dataset, pass the chain to the model

```julia
laplace_model = LaplaceClassifier(model = Flux_Chain, kwargs...)
```

then bind an instance `laplace_model` to data with

```julia
mach = machine(laplace_model, X, y)
```

where

- `X`: any table of input features (eg, a `DataFrame`) whose columns each have one of the following element scitypes: `Continuous`, `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype with `scitype(y)`

Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)

- `model::Union{Flux.Chain,Nothing} = nothing`: either `nothing` or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.

- `flux_loss = Flux.Losses.logitcrossentropy`: a Flux loss function

- `optimiser = Adam()`: a Flux optimiser

- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs

- `batch_size::Integer = 32::(_ > 0)`: the batch size

- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`

- `subnetwork_indices = nothing`: the indices of the subnetworks

- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`

- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`

- `observational_noise (alias σ)::Float64 = 1.0`: the standard deviation of the observational noise

- `prior_mean (alias μ₀)::Float64 = 0.0`: the mean of the prior distribution

- `prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the precision matrix of the prior distribution

- `fit_prior_nsteps::Int = 100::(_ > 0)`: the number of steps used to fit the priors

- `link_approx::Symbol = :probit::(_ in (:probit, :plugin))`: the approximation to adopt to compute the predictive probabilities, either `:probit` or `:plugin`
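As a sketch of overriding these defaults at construction time, the following builds a last-layer Laplace approximation with a diagonal Hessian (the specific chain architecture is illustrative, matching a 4-feature, 3-class dataset):

```julia
using MLJ, Flux

LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

# Last-layer Laplace approximation with a diagonal Hessian structure,
# using only hyperparameters documented above:
model = LaplaceClassifier(
    model             = Chain(Dense(4, 20, relu), Dense(20, 3)),
    epochs            = 500,
    batch_size        = 64,
    subset_of_weights = :last_layer,
    hessian_structure = :diagonal,
    backend           = :GGN,
    link_approx       = :probit,
)
```
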
# Operations

- `predict(mach, Xnew)`: return predictions of the target given features `Xnew` having the same scitype as `X` above. Predictions are probabilistic, but uncalibrated.

- `predict_mode(mach, Xnew)`: instead return the mode of each prediction above.
# Fitted parameters

The fields of `fitted_params(mach)` are:

- `mean`: the mean of the posterior distribution

- `H`: the Hessian of the posterior distribution

- `P`: the precision matrix of the posterior distribution

- `cov_matrix`: the covariance matrix of the posterior distribution

- `n_data`: the number of data points

- `n_params`: the number of parameters

- `n_out`: the number of outputs

- `loss`: the loss value of the posterior distribution
# Report

The fields of `report(mach)` are:

- `loss_history`: an array containing the total loss per epoch
# Accessor functions

- `training_losses(mach)`: return the loss history from the report
# Examples

```julia
using MLJ
using Flux

LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

X, y = @load_iris

# Define the Flux chain (4 input features, 3 output classes for iris)
flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 3)
)

# Define the LaplaceClassifier
model = LaplaceClassifier(model=flux_model)

mach = machine(model, X, y) |> fit!

Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)

yhat = predict(mach, Xnew)   # probabilistic predictions
predict_mode(mach, Xnew)     # point predictions
training_losses(mach)        # loss history per epoch
pdf.(yhat, "virginica")      # probabilities for the "virginica" class
fitted_params(mach)          # NamedTuple with the fitted params of Laplace
```
See also LaplaceRedux.jl.