PegasosClassifier

mutable struct PegasosClassifier <: MLJModelInterface.Probabilistic

The gradient-based linear "pegasos" classifier, using a one-vs-all strategy for multiclass problems, from the Beta Machine Learning Toolkit (BetaML).
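To illustrate the idea, here is a minimal sketch of a textbook Pegasos-style update (stochastic subgradient descent on the regularised hinge loss) combined with one-vs-all. It is not BetaML's implementation: the helper names `pegasos_binary`, `pegasos_ova` and `predict_ova`, the regularisation weight `λ`, the margin test `< 1` and the unregularised constant term are all assumptions made for the example, and the sketch does not claim how BetaML maps its own hyperparameters onto these quantities.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not BetaML's code): Pegasos-style SGD on the hinge loss
# for a single binary problem with labels y ∈ {-1, +1}.
function pegasos_binary(X::Matrix{Float64}, y::Vector{Int};
                        λ = 0.5, η = t -> 1 / sqrt(t), epochs = 100,
                        rng = MersenneTwister(0))
    n, d = size(X)
    θ, θ₀, t = zeros(d), 0.0, 0
    for _ in 1:epochs
        for i in randperm(rng, n)            # visit the rows in a new random order each epoch
            t += 1
            ηt = η(t)
            x = X[i, :]
            if y[i] * (dot(θ, x) + θ₀) < 1  # margin violated: hinge loss is active
                θ  = (1 - ηt * λ) .* θ .+ (ηt * y[i]) .* x
                θ₀ += ηt * y[i]              # constant term left unregularised in this sketch
            else                             # margin satisfied: only the shrinkage step
                θ  = (1 - ηt * λ) .* θ
            end
        end
    end
    return θ, θ₀
end

# Hypothetical one-vs-all wrapper: one binary model per class.
function pegasos_ova(X::Matrix{Float64}, labels::AbstractVector; kwargs...)
    classes = unique(labels)
    models = [pegasos_binary(X, ifelse.(labels .== c, 1, -1); kwargs...) for c in classes]
    return classes, models
end

# Predict the class whose binary model gives the highest score.
function predict_ova(classes, models, X::Matrix{Float64})
    scores = hcat([X * θ .+ θ₀ for (θ, θ₀) in models]...)   # n × K score matrix
    return [classes[argmax(scores[i, :])] for i in 1:size(X, 1)]
end
```

For example, on the separable toy data `Xb = [1.0 1.0; 2.0 2.0; -1.0 -1.0; -2.0 -2.0]` with classes `["a", "a", "b", "b"]`, `predict_ova(pegasos_ova(Xb, labels)..., Xb)` recovers the training labels. The real model returns class probabilities rather than the hard argmax used here.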

Hyperparameters:

  • initial_coefficients::Union{Nothing, Matrix{Float64}}: N-classes by D-dimensions matrix of initial linear coefficients [def: nothing, i.e. zeros]
  • initial_constant::Union{Nothing, Vector{Float64}}: N-classes vector of initial constant terms [def: nothing, i.e. zeros]
  • learning_rate::Function: Learning rate, as a function of the epoch number [def: (epoch -> 1/sqrt(epoch))]
  • learning_rate_multiplicative::Float64: Multiplicative term of the learning rate [def: 0.5]
  • epochs::Int64: Maximum number of epochs, i.e. passes through the whole training sample [def: 1000]
  • shuffle::Bool: Whether to randomly shuffle the data at each iteration (epoch) [def: true]
  • force_origin::Bool: Whether to force the parameter associated with the constant term to remain zero [def: false]
  • return_mean_hyperplane::Bool: Whether to return the average hyperplane coefficients instead of the final ones [def: false]
  • rng::Random.AbstractRNG: A random number generator to be used in the stochastic parts of the code [def: Random.GLOBAL_RNG]
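The hyperparameters above are keyword arguments of the model constructor. The sketch below, assuming MLJ and BetaML are installed in the active environment, overrides a few of them; the particular values (a `1/epoch` schedule, 200 epochs, seed 123) are illustrative choices, not recommendations.

```julia
using MLJ, Random

# Load the model type from BetaML through MLJ's model registry.
PegasosClassifier = @load PegasosClassifier pkg="BetaML" verbosity=0

# Override selected hyperparameters; the rest keep their defaults.
model = PegasosClassifier(
    learning_rate = epoch -> 1 / epoch,    # decays faster than the default 1/sqrt(epoch)
    epochs        = 200,                   # fewer passes through the training sample
    shuffle       = true,
    rng           = MersenneTwister(123),  # fixed seed so the shuffling is reproducible
)

# The model struct is mutable, so hyperparameters can also be changed afterwards.
model.return_mean_hyperplane = true
```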

Example:

julia> using MLJ

julia> X, y        = @load_iris;

julia> modelType   = @load PegasosClassifier pkg = "BetaML" verbosity=0
BetaML.Perceptron.PegasosClassifier

julia> model       = modelType()
PegasosClassifier(
  initial_coefficients = nothing, 
  initial_constant = nothing, 
  learning_rate = BetaML.Perceptron.var"#71#73"(), 
  learning_rate_multiplicative = 0.5, 
  epochs = 1000, 
  shuffle = true, 
  force_origin = false, 
  return_mean_hyperplane = false, 
  rng = Random._GLOBAL_RNG())

julia> mach        = machine(model, X, y);

julia> fit!(mach);

julia> est_classes = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.817, versicolor=>0.153, virginica=>0.0301)
 UnivariateFinite{Multiclass{3}}(setosa=>0.791, versicolor=>0.177, virginica=>0.0318)
 ⋮
 UnivariateFinite{Multiclass{3}}(setosa=>0.254, versicolor=>0.5, virginica=>0.246)
 UnivariateFinite{Multiclass{3}}(setosa=>0.283, versicolor=>0.51, virginica=>0.207)
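Because the model is probabilistic, `predict` returns one distribution per row rather than a class label. A sketch of the usual MLJ follow-ups, assuming the same setup as the example above: `predict_mode` extracts the most probable class, `pdf` reads off the probability of a given class, and `accuracy` scores hard predictions. The accuracy here is computed on the training data, so it is an optimistic in-sample figure, not an estimate of generalisation performance.

```julia
using MLJ

X, y = @load_iris
PegasosClassifier = @load PegasosClassifier pkg="BetaML" verbosity=0
mach = machine(PegasosClassifier(), X, y)
fit!(mach, verbosity=0)

ŷ_dist   = predict(mach, X)         # vector of UnivariateFinite distributions
ŷ_class  = predict_mode(mach, X)    # most probable class for each row
p_setosa = pdf.(ŷ_dist, "setosa")   # probability assigned to "setosa", row by row
acc      = accuracy(ŷ_class, y)     # in-sample accuracy (optimistic)
```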