
mutable struct NeuralNetworkClassifier <: MLJModelInterface.Probabilistic

A simple but flexible feedforward neural network for classification problems, from the Beta Machine Learning Toolkit (BetaML).


  • layers: Array of layer objects [def: nothing, i.e. basic network]. See subtypes(BetaML.AbstractLayer) for supported layers. The last "softmax" layer is automatically added.

  • loss: Loss (cost) function [def: BetaML.crossentropy]. Should always assume y and ŷ as matrices.


    If you change the loss parameter, you must either provide its derivative through the dloss parameter or use autodiff by setting dloss=nothing.

  • dloss: Derivative of the loss function [def: BetaML.dcrossentropy, i.e. the derivative of the cross-entropy]. Use nothing for autodiff.

  • epochs: Number of epochs, i.e. passes through the whole training sample [def: 200]

  • batch_size: Size of each individual batch [def: 16]

  • opt_alg: The optimisation algorithm used to update the gradient at each batch [def: BetaML.ADAM()]. See subtypes(BetaML.OptimisationAlgorithm) for supported optimisers.

  • shuffle: Whether to randomly shuffle the data at each iteration (epoch) [def: true]

  • descr: An optional title and/or description for this model

  • cb: A callback function to provide information during training [def: BetaML.fitting_info]

  • categories: The categories to represent as columns. [def: nothing, i.e. unique training values].

  • handle_unknown: How to handle categories not seen in training or not present in the provided categories array. "error" (default) raises an error, "infrequent" adds a specific column for these categories.

  • other_categories_name: Which value to assign during prediction to this "other" category (i.e. categories not seen in training or not present in the provided categories array) [def: nothing, i.e. typemax(Int64) for integer vectors and "other" for other types]. This setting is active only if handle_unknown="infrequent", and in that case it MUST be specified if Y is neither integers nor strings.

  • rng: Random Number Generator [def: Random.GLOBAL_RNG]
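
As a rough illustration of the loss and dloss pairing described above, the sketch below defines a custom loss with a hand-written derivative and checks the derivative numerically. The names my_loss, my_dloss and numeric_gradient are hypothetical and not part of BetaML. It assumes, mirroring BetaML.crossentropy and BetaML.dcrossentropy, that both functions take (y, ŷ) and that dloss returns the derivative with respect to ŷ.

```julia
# Hypothetical custom loss: half the sum of squared errors.
# Assumption: same (y, ŷ) argument order as BetaML.crossentropy.
my_loss(y, ŷ) = sum((y .- ŷ) .^ 2) / 2

# Its hand-written derivative with respect to the prediction ŷ
my_dloss(y, ŷ) = ŷ .- y

# Central finite differences, used here only to check my_dloss
function numeric_gradient(f, y, ŷ; h=1e-6)
    g = similar(ŷ, Float64)
    for i in eachindex(ŷ)
        up = copy(ŷ); up[i] += h
        down = copy(ŷ); down[i] -= h
        g[i] = (f(y, up) - f(y, down)) / (2h)
    end
    return g
end

y = [1.0 0.0 0.0; 0.0 1.0 0.0]  # two one-hot encoded records
ŷ = [0.7 0.2 0.1; 0.3 0.4 0.3]  # two predicted probability rows

# true when the analytic and the numeric derivatives agree
gradient_ok = isapprox(my_dloss(y, ŷ), numeric_gradient(my_loss, y, ŷ); atol=1e-6)

# Assumed usage with the model type loaded through MLJ:
#   modelType(loss=my_loss, dloss=my_dloss)   # explicit derivative
#   modelType(loss=my_loss, dloss=nothing)    # let autodiff derive it
```

Checking a hand-written derivative this way before training is cheap. When no derivative is available, dloss=nothing delegates the work to autodiff, as described for the dloss parameter.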


  • data must be numerical
  • the label should be an n-records by n-dimensions matrix (e.g. one-hot-encoded data for classification), where the output columns should be interpreted as the probabilities for each category.
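
To make the one-hot layout and the categories / handle_unknown options more concrete, here is a minimal plain-Julia sketch. The helper onehot_sketch is hypothetical and only imitates the documented behaviour; it is not the encoder BetaML uses internally.

```julia
# Hypothetical helper (not a BetaML function): one column per category,
# plus an optional extra column for values outside `categories`.
function onehot_sketch(labels, categories; handle_unknown="error")
    ncols = handle_unknown == "infrequent" ? length(categories) + 1 : length(categories)
    out = zeros(Int, length(labels), ncols)
    for (i, label) in enumerate(labels)
        j = findfirst(==(label), categories)
        if j === nothing
            # with "error", an unknown label stops the encoding
            handle_unknown == "infrequent" || error("unknown category: $label")
            j = ncols  # the extra "other" column
        end
        out[i, j] = 1
    end
    return out
end

categories = ["setosa", "versicolor", "virginica"]
labels     = ["versicolor", "setosa", "unseen"]

# n-records by (n-categories + 1) matrix; the last column collects unknown labels
encoded = onehot_sketch(labels, categories; handle_unknown="infrequent")
```

With handle_unknown="error" the same call would raise an error on the "unseen" label instead of adding a column.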


julia> using MLJ

julia> X, y        = @load_iris;

julia> modelType   = @load NeuralNetworkClassifier pkg = "BetaML" verbosity=0

julia> layers      = [BetaML.DenseLayer(4,8,f=BetaML.relu),BetaML.DenseLayer(8,8,f=BetaML.relu),BetaML.DenseLayer(8,3,f=BetaML.relu),BetaML.VectorFunctionLayer(3,f=BetaML.softmax)];

julia> model       = modelType(layers=layers,opt_alg=BetaML.ADAM())
  layers = BetaML.Nn.AbstractLayer[BetaML.Nn.DenseLayer([-0.376173352338049 0.7029289511758696 -0.5589563304592478 -0.21043274001651874; 0.044758889527899415 0.6687689636685921 0.4584331114653877 0.6820506583840453; … ; -0.26546358457167507 -0.28469736227283804 -0.164225549922154 -0.516785639164486; -0.5146043550684141 -0.0699113265130964 0.14959906603941908 -0.053706860039406834], [0.7003943613125758, -0.23990840466587576, -0.23823126271387746, 0.4018101580410387, 0.2274483050356888, -0.564975060667734, 0.1732063297031089, 0.11880299829896945], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([-0.029467850439546583 0.4074661266592745 … 0.36775675246760053 -0.595524555448422; 0.42455597698371306 -0.2458082732997091 … -0.3324220683462514 0.44439454998610595; … ; -0.2890883863364267 -0.10109249362508033 … -0.0602680568207582 0.18177278845097555; -0.03432587226449335 -0.4301192922760063 … 0.5646018168286626 0.47269177680892693], [0.13777442835428688, 0.5473306726675433, 0.3781939472904011, 0.24021813428130567, -0.0714779477402877, -0.020386373530818958, 0.5465466618404464, -0.40339790713616525], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.DenseLayer([0.6565120540082393 0.7139211611842745 … 0.07809812467915389 -0.49346311403373844; -0.4544472987041656 0.6502667641568863 … 0.43634608676548214 0.7213049952968921; 0.41212264783075303 -0.21993289366360613 … 0.25365007887755064 -0.5664469566269569], [-0.6911986792747682, -0.2149343209329364, -0.6347727539063817], BetaML.Utils.relu, BetaML.Utils.drelu), BetaML.Nn.VectorFunctionLayer{0}(fill(NaN), 3, 3, BetaML.Utils.softmax, BetaML.Utils.dsoftmax, nothing)], 
  loss = BetaML.Utils.crossentropy, 
  dloss = BetaML.Utils.dcrossentropy, 
  epochs = 100, 
  batch_size = 32, 
  opt_alg = BetaML.Nn.ADAM(BetaML.Nn.var"#90#93"(), 1.0, 0.9, 0.999, 1.0e-8, BetaML.Nn.Learnable[], BetaML.Nn.Learnable[]), 
  shuffle = true, 
  descr = "", 
  cb = BetaML.Nn.fitting_info, 
  categories = nothing, 
  handle_unknown = "error", 
  other_categories_name = nothing, 
  rng = Random._GLOBAL_RNG())

julia> mach        = machine(model, X, y);

julia> fit!(mach);

julia> classes_est = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.575, versicolor=>0.213, virginica=>0.213)
 UnivariateFinite{Multiclass{3}}(setosa=>0.573, versicolor=>0.213, virginica=>0.213)
 UnivariateFinite{Multiclass{3}}(setosa=>0.236, versicolor=>0.236, virginica=>0.529)
 UnivariateFinite{Multiclass{3}}(setosa=>0.254, versicolor=>0.254, virginica=>0.492)
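
The predictions above are probability distributions over the three classes. A minimal sketch of turning such rows into hard labels is given below, using only the four probability rows printed above; the variable names are illustrative. In MLJ itself, predict_mode(mach, X) or mode.(classes_est) should serve the same purpose.

```julia
# Class order and the four probability rows printed above
classes = ["setosa", "versicolor", "virginica"]
probs   = [0.575 0.213 0.213;
           0.573 0.213 0.213;
           0.236 0.236 0.529;
           0.254 0.254 0.492]

# For each record (row), pick the class with the highest probability
hard_labels = [classes[argmax(row)] for row in eachrow(probs)]
```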