PegasosClassifier
mutable struct PegasosClassifier <: MLJModelInterface.Probabilistic
The gradient-based linear "pegasos" classifier using a one-vs-all strategy for multiclass problems, from the Beta Machine Learning Toolkit (BetaML).
Hyperparameters:
initial_coefficients::Union{Nothing, Matrix{Float64}}: N-classes by D-dimensions matrix of initial linear coefficients [default: nothing, i.e. zeros]
initial_constant::Union{Nothing, Vector{Float64}}: N-classes vector of initial constant terms [default: nothing, i.e. zeros]
learning_rate::Function: Learning rate as a function of the epoch [default: (epoch -> 1/sqrt(epoch))]
learning_rate_multiplicative::Float64: Multiplicative term of the learning rate [default: 0.5]
epochs::Int64: Maximum number of epochs, i.e. passes through the whole training sample [default: 1000]
shuffle::Bool: Whether to randomly shuffle the data at each iteration (epoch) [default: true]
force_origin::Bool: Whether to force the parameter associated with the constant term to remain zero [default: false]
return_mean_hyperplane::Bool: Whether to return the average hyperplane coefficients instead of the final ones [default: false]
rng::Random.AbstractRNG: A random number generator to be used in the stochastic parts of the code [default: Random.GLOBAL_RNG]
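To make the roles of these hyperparameters more concrete, here is a minimal sketch of a textbook Pegasos-style stochastic update for a single binary (one-vs-all) subproblem with labels in {-1, +1}. It is not BetaML's internal code: the function name `pegasos_binary_sketch` is hypothetical, the hinge-margin test `< 1` follows the textbook formulation, and treating `learning_rate_multiplicative` as the regularization factor that shrinks the coefficients is an assumption. Initial values and `return_mean_hyperplane` are left out for brevity.

```julia
using LinearAlgebra
using Random

# Hypothetical helper, NOT part of BetaML: trains one binary Pegasos-style
# linear classifier on labels y ∈ {-1, +1}. Keyword names mirror the
# hyperparameters above only for readability.
function pegasos_binary_sketch(X::Matrix{Float64}, y::AbstractVector{<:Integer};
                               learning_rate = epoch -> 1 / sqrt(epoch),
                               learning_rate_multiplicative = 0.5,
                               epochs = 1000, shuffle = true,
                               force_origin = false,
                               rng = Random.default_rng())
    n, d = size(X)
    θ = zeros(d)    # linear coefficients
    θ₀ = 0.0        # constant term
    λ = learning_rate_multiplicative   # ASSUMPTION: used as the shrink (regularization) factor
    for epoch in 1:epochs
        η = learning_rate(epoch)                       # step size for this epoch
        order = shuffle ? randperm(rng, n) : collect(1:n)
        for i in order
            x = X[i, :]
            if y[i] * (dot(θ, x) + θ₀) < 1             # hinge margin violated
                θ = (1 - η * λ) .* θ .+ (η * y[i]) .* x
                θ₀ = force_origin ? 0.0 : θ₀ + η * y[i]
            else
                θ = (1 - η * λ) .* θ                   # only shrink the coefficients
            end
        end
    end
    return θ, θ₀
end
```

Setting `force_origin = true` keeps `θ₀` at zero, so the hyperplane always passes through the origin; `shuffle` only changes the order in which samples are visited in each epoch.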
Example:
julia> using MLJ
julia> X, y = @load_iris;
julia> modelType = @load PegasosClassifier pkg = "BetaML" verbosity=0
BetaML.Perceptron.PegasosClassifier
julia> model = modelType()
PegasosClassifier(
initial_coefficients = nothing,
initial_constant = nothing,
learning_rate = BetaML.Perceptron.var"#71#73"(),
learning_rate_multiplicative = 0.5,
epochs = 1000,
shuffle = true,
force_origin = false,
return_mean_hyperplane = false,
rng = Random._GLOBAL_RNG())
julia> mach = machine(model, X, y);
julia> fit!(mach);
julia> est_probs = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
UnivariateFinite{Multiclass{3}}(setosa=>0.817, versicolor=>0.153, virginica=>0.0301)
UnivariateFinite{Multiclass{3}}(setosa=>0.791, versicolor=>0.177, virginica=>0.0318)
⋮
UnivariateFinite{Multiclass{3}}(setosa=>0.254, versicolor=>0.5, virginica=>0.246)
UnivariateFinite{Multiclass{3}}(setosa=>0.283, versicolor=>0.51, virginica=>0.207)
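In the one-vs-all setting each class has its own hyperplane (a row of the N-classes by D-dimensions coefficient matrix plus one constant term), and the per-class scores still have to be turned into the probabilities shown above. The sketch below assumes a softmax normalization of the scores; the helper name `ova_probabilities_sketch` is hypothetical and BetaML's exact normalization may differ.

```julia
# Hypothetical helper, NOT part of BetaML: scores each class with its own
# hyperplane (one-vs-all) and normalizes the scores with a softmax.
function ova_probabilities_sketch(coefficients::Matrix{Float64},
                                  constants::Vector{Float64},
                                  x::Vector{Float64})
    scores = coefficients * x .+ constants   # one score per class
    z = exp.(scores .- maximum(scores))      # subtract the max for numerical stability
    return z ./ sum(z)                       # probabilities summing to 1
end

# Usage: three classes, two features
p = ova_probabilities_sketch([1.0 0.0; 0.0 1.0; -1.0 -1.0], zeros(3), [2.0, 0.5])
best_class = argmax(p)   # index of the most probable class
```

In MLJ itself, `predict_mode(mach, X)` returns the most probable class for each row instead of the full distribution.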