NeuralNetworkClassifier
mutable struct NeuralNetworkClassifier <: MLJModelInterface.Probabilistic
A simple but flexible feedforward neural network for classification problems, from the Beta Machine Learning Toolkit (BetaML).
Parameters:
- layers: Array of layer objects [def: nothing, i.e. basic network]. See subtypes(BetaML.AbstractLayer) for supported layers. The last "softmax" layer is automatically added.
- loss: Loss (cost) function [def: BetaML.crossentropy]. Should always assume y and ŷ as matrices. Warning: if you change the parameter loss, you need to either provide its derivative in the parameter dloss or use autodiff with dloss=nothing.
- dloss: Derivative of the loss function [def: BetaML.dcrossentropy, i.e. the derivative of the cross-entropy]. Use nothing for autodiff.
- epochs: Number of epochs, i.e. passes through the whole training sample [def: 200]
- batch_size: Size of each individual batch [def: 16]
- opt_alg: The optimisation algorithm to update the gradient at each batch [def: BetaML.ADAM()]. See subtypes(BetaML.OptimisationAlgorithm) for supported optimizers.
- shuffle: Whether to randomly shuffle the data at each iteration (epoch) [def: true]
- descr: An optional title and/or description for this model
- cb: A callback function to provide information during training [def: BetaML.fitting_info]
- categories: The categories to represent as columns [def: nothing, i.e. unique training values].
- handle_unknown: How to handle categories not seen in training or not present in the provided categories array: "error" (default) raises an error, "infrequent" adds a specific column for these categories.
- other_categories_name: Which value to assign during prediction to this "other" category (i.e. categories not seen in training or not present in the provided categories array) [def: nothing, i.e. typemax(Int64) for integer vectors and "other" for other types]. This setting is active only if handle_unknown="infrequent", and in that case it MUST be specified if Y is neither integers nor strings.
- rng: Random Number Generator [def: Random.GLOBAL_RNG]
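The loss and dloss parameters come as a pair: a cost function and its derivative with respect to the prediction. The following is a minimal plain-Julia sketch of such a pair, checked against a finite difference. The names my_crossentropy and my_dcrossentropy are hypothetical and are not BetaML's own definitions; whether BetaML calls a custom loss with exactly this two-argument signature is an assumption to verify against the package.

```julia
# Hypothetical loss/derivative pair for illustration only (not BetaML's code).
# Both accept y (targets) and ŷ (predictions) as arrays of the same shape.
my_crossentropy(y, ŷ)  = -sum(y .* log.(ŷ .+ 1e-15))
my_dcrossentropy(y, ŷ) = -y ./ (ŷ .+ 1e-15)   # derivative with respect to ŷ

# One-hot targets and predicted probabilities, one record per row.
y = [1.0 0.0 0.0;
     0.0 1.0 0.0]
ŷ = [0.7 0.2 0.1;
     0.1 0.8 0.1]

loss_value = my_crossentropy(y, ŷ)
grad       = my_dcrossentropy(y, ŷ)

# Finite-difference check of a single gradient entry, here element (1, 1).
h = 1e-6
ŷ_shifted = copy(ŷ)
ŷ_shifted[1, 1] += h
fd = (my_crossentropy(y, ŷ_shifted) - loss_value) / h
```

If a hand-written derivative is not available, the parameter list above states that dloss=nothing switches to autodiff instead.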
Notes:
- data must be numerical
- the label should be an n-records by n-dimensions matrix (e.g. one-hot-encoded data for classification), where the output columns are interpreted as the probabilities of each category.
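To illustrate the matrix layout the note describes, here is a small plain-Julia sketch that one-hot encodes a label vector, with one row per record and one column per category. The variable names are hypothetical and no BetaML function is used. In the MLJ example of this section the target is passed as a categorical vector, so this only shows the layout, not a required preprocessing step.

```julia
# Illustrative labels; the column order follows the sorted unique categories.
labels     = ["setosa", "virginica", "setosa", "versicolor"]
categories = sort(unique(labels))

# Records along the rows, categories along the columns.
onehot = [label == c ? 1.0 : 0.0 for label in labels, c in categories]
```

Each row contains a single 1.0, so it can be read as a degenerate probability distribution over the categories.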
Example:
julia> using MLJ
julia> X, y = @load_iris;
julia> modelType = @load NeuralNetworkClassifier pkg = "BetaML" verbosity=0
BetaML.Nn.NeuralNetworkClassifier
julia> layers = [BetaML.DenseLayer(4,8,f=BetaML.relu),BetaML.DenseLayer(8,8,f=BetaML.relu),BetaML.DenseLayer(8,3,f=BetaML.relu),BetaML.VectorFunctionLayer(3,f=BetaML.softmax)];
julia> model = modelType(layers=layers,opt_alg=BetaML.ADAM())
NeuralNetworkClassifier(
layers = BetaML.Nn.AbstractLayer[BetaML.Nn.DenseLayer([-0.376173352338049 0.7029289511758696 -0.5589563304592478 -0.21043274001651874; 0.044758889527899415 0.6687689636685921 0.4584331114653877 0.6820506583840453; … ; -0.26546358457167507 -0.28469736227283804 -0.164225549922154 -0.516785639164486; -0.5146043550684141 -0.0699113265130964 0.14959906603941908 -0.053706860039406834], [0.7003943613125758, -0.23990840466587576, -0.23823126271387746, 0.4018101580410387, 0.2274483050356888, -0.564975060667734, 0.1732063297031089, 0.11880299829896945], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([-0.029467850439546583 0.4074661266592745 … 0.36775675246760053 -0.595524555448422; 0.42455597698371306 -0.2458082732997091 … -0.3324220683462514 0.44439454998610595; … ; -0.2890883863364267 -0.10109249362508033 … -0.0602680568207582 0.18177278845097555; -0.03432587226449335 -0.4301192922760063 … 0.5646018168286626 0.47269177680892693], [0.13777442835428688, 0.5473306726675433, 0.3781939472904011, 0.24021813428130567, -0.0714779477402877, -0.020386373530818958, 0.5465466618404464, -0.40339790713616525], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([0.6565120540082393 0.7139211611842745 … 0.07809812467915389 -0.49346311403373844; -0.4544472987041656 0.6502667641568863 … 0.43634608676548214 0.7213049952968921; 0.41212264783075303 -0.21993289366360613 … 0.25365007887755064 -0.5664469566269569], [-0.6911986792747682, -0.2149343209329364, -0.6347727539063817], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.VectorFunctionLayer{0}(fill(NaN), 3, 3, BetaML.Utils.softmax, BetaML.Utils.dsoftmax, nothing)],
loss = BetaML.Utils.crossentropy,
dloss = BetaML.Utils.dcrossentropy,
epochs = 100,
batch_size = 32,
opt_alg = BetaML.Nn.ADAM(BetaML.Nn.var"#90#93"(), 1.0, 0.9, 0.999, 1.0e-8, BetaML.Nn.Learnable[], BetaML.Nn.Learnable[]),
shuffle = true,
descr = "",
cb = BetaML.Nn.fitting_info,
categories = nothing,
handle_unknown = "error",
other_categories_name = nothing,
rng = Random._GLOBAL_RNG())
julia> mach = machine(model, X, y);
julia> fit!(mach);
julia> classes_est = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
UnivariateFinite{Multiclass{3}}(setosa=>0.575, versicolor=>0.213, virginica=>0.213)
UnivariateFinite{Multiclass{3}}(setosa=>0.573, versicolor=>0.213, virginica=>0.213)
⋮
UnivariateFinite{Multiclass{3}}(setosa=>0.236, versicolor=>0.236, virginica=>0.529)
UnivariateFinite{Multiclass{3}}(setosa=>0.254, versicolor=>0.254, virginica=>0.492)
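Because predict returns one probability distribution per record, a common follow-up is to reduce each distribution to a single label or to read off one category's probability. The sketch below repeats the setup from the example above and uses the MLJ functions mode, pdf and accuracy. It assumes MLJ and BetaML are installed; the variable names classes_hat, p_setosa and acc are illustrative, and no output is shown because the fitted values depend on the random initialisation.

```julia
using MLJ

X, y = @load_iris;
modelType = @load NeuralNetworkClassifier pkg = "BetaML" verbosity=0
model = modelType(epochs=100)           # `epochs` is one of the documented parameters
mach = machine(model, X, y);
fit!(mach);

classes_est = predict(mach, X)          # vector of UnivariateFinite distributions
classes_hat = mode.(classes_est)        # most probable category for each record
p_setosa    = pdf.(classes_est, "setosa")  # probability assigned to "setosa"
acc         = accuracy(classes_hat, y)  # in-sample accuracy, not a test-set estimate
```

For an honest performance estimate the data would normally be split into training and test rows first; the in-sample figure here only shows how the pieces connect.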