PegasosClassifier
mutable struct PegasosClassifier <: MLJModelInterface.Probabilistic
A gradient-based linear classifier implementing the "Pegasos" algorithm, extended to multiclass problems with a one-vs-all strategy, from the Beta Machine Learning Toolkit (BetaML).
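One-vs-all means that one binary linear classifier is trained per class. A new point is then assigned to the class whose classifier gives the highest score. The Julia sketch below illustrates only that decision rule. It assumes one coefficient row and one constant per class, consistent with the N-classes by D-dimensions shape of `initial_coefficients` described under Hyperparameters. The helper names `ova_scores` and `ova_predict` are hypothetical, and how BetaML turns these scores into the class probabilities returned by `predict` is not shown here.

```julia
# Hypothetical illustration of the one-vs-all decision rule (not BetaML's code).
# Assumption: row k of θ and entry k of θ₀ hold the binary classifier for class k.

"""
    ova_scores(θ, θ₀, x)

Per-class linear scores θ[k, :]' * x + θ₀[k], one entry per class.
"""
ova_scores(θ::Matrix{Float64}, θ₀::Vector{Float64}, x::Vector{Float64}) = θ * x .+ θ₀

"""
    ova_predict(θ, θ₀, x)

Index of the class whose one-vs-all classifier gives the largest score.
"""
ova_predict(θ, θ₀, x) = argmax(ova_scores(θ, θ₀, x))

# Usage: 3 classes, 2 features (toy values chosen for illustration)
θ  = [1.0 0.0; 0.0 1.0; -1.0 -1.0]
θ₀ = [0.0, 0.0, 0.5]
x  = [0.25, 0.75]
println(ova_scores(θ, θ₀, x))   # prints [0.25, 0.75, -0.5]
println(ova_predict(θ, θ₀, x))  # prints 2
```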
Hyperparameters:
initial_coefficients::Union{Nothing, Matrix{Float64}}: N-classes by D-dimensions matrix of initial linear coefficients [def: nothing, i.e. zeros]
initial_constant::Union{Nothing, Vector{Float64}}: N-classes vector of initial constant terms [def: nothing, i.e. zeros]
learning_rate::Function: Learning rate as a function of the epoch [def: (epoch -> 1/sqrt(epoch))]
learning_rate_multiplicative::Float64: Multiplicative term of the learning rate [def: 0.5]
epochs::Int64: Maximum number of epochs, i.e. passes through the whole training sample [def: 1000]
shuffle::Bool: Whether to randomly shuffle the data at each iteration (epoch) [def: true]
force_origin::Bool: Whether to force the parameter associated with the constant term to remain zero [def: false]
return_mean_hyperplane::Bool: Whether to return the average hyperplane coefficients instead of the final ones [def: false]
rng::Random.AbstractRNG: A random number generator to be used in the stochastic parts of the code [def: Random.GLOBAL_RNG]
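To show how these hyperparameters interact during training, here is a minimal Julia sketch of a Pegasos-style loop for a single binary (one-vs-all) subproblem with labels in {-1, +1}. It is an illustration under stated assumptions, not BetaML's implementation. The function name `pegasos_sketch` is hypothetical. It uses the textbook Pegasos hinge condition (update when the margin is below 1) and treats `learning_rate_multiplicative` as the shrinkage factor `λ` in `(1 - η*λ)*θ`. The "mean hyperplane" is taken as the average of the end-of-epoch coefficients. All three choices are assumptions.

```julia
using Random, LinearAlgebra

# Hypothetical Pegasos-style sketch for ONE binary subproblem (labels in {-1, +1}).
# Not BetaML's code. Assumptions:
#   - λ stands in for `learning_rate_multiplicative` (shrinkage factor);
#   - the step size at epoch t is learning_rate(t), default 1/sqrt(t);
#   - a sample triggers an update when its margin y*(θ'x + θ₀) is below 1;
#   - the "mean hyperplane" averages the end-of-epoch coefficients.
function pegasos_sketch(X::Matrix{Float64}, y::Vector{Int};
                        learning_rate = t -> 1 / sqrt(t),
                        λ = 0.5,
                        epochs = 1000,
                        shuffle = true,
                        force_origin = false,
                        return_mean_hyperplane = false,
                        rng::AbstractRNG = Random.GLOBAL_RNG)
    n, d = size(X)
    θ, θ₀ = zeros(d), 0.0            # start from zeros, like the `nothing` defaults
    θ_sum, θ₀_sum = zeros(d), 0.0    # running sums for the mean hyperplane
    for t in 1:epochs
        η = learning_rate(t)
        order = shuffle ? randperm(rng, n) : collect(1:n)   # visit order for this epoch
        for i in order
            xᵢ = X[i, :]
            if y[i] * (dot(θ, xᵢ) + θ₀) < 1
                # Hinge loss active: shrink, then step towards y[i]*xᵢ
                θ = (1 - η * λ) .* θ .+ (η * y[i]) .* xᵢ
                if !force_origin
                    θ₀ += η * y[i]           # constant term stays 0 when force_origin
                end
            else
                # Margin already ≥ 1: only apply the shrinkage
                θ = (1 - η * λ) .* θ
            end
        end
        θ_sum .+= θ
        θ₀_sum += θ₀
    end
    return return_mean_hyperplane ? (θ_sum ./ epochs, θ₀_sum / epochs) : (θ, θ₀)
end

# Usage on a tiny toy set (invented for illustration)
X = [2.0 2.0; 1.5 2.5; -2.0 -2.0; -1.0 -2.5]
y = [1, 1, -1, -1]
θ, θ₀ = pegasos_sketch(X, y; epochs = 100, rng = MersenneTwister(42))
```

In a one-vs-all setting, this loop would be run once per class, relabelling that class as +1 and all others as -1.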
Example:
julia> using MLJ
julia> X, y = @load_iris;
julia> modelType = @load PegasosClassifier pkg = "BetaML" verbosity=0
BetaML.Perceptron.PegasosClassifier
julia> model = modelType()
PegasosClassifier(
initial_coefficients = nothing,
initial_constant = nothing,
learning_rate = BetaML.Perceptron.var"#71#73"(),
learning_rate_multiplicative = 0.5,
epochs = 1000,
shuffle = true,
force_origin = false,
return_mean_hyperplane = false,
rng = Random._GLOBAL_RNG())
julia> mach = machine(model, X, y);
julia> fit!(mach);
julia> est_probs = predict(mach, X)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
UnivariateFinite{Multiclass{3}}(setosa=>0.817, versicolor=>0.153, virginica=>0.0301)
UnivariateFinite{Multiclass{3}}(setosa=>0.791, versicolor=>0.177, virginica=>0.0318)
⋮
UnivariateFinite{Multiclass{3}}(setosa=>0.254, versicolor=>0.5, virginica=>0.246)
UnivariateFinite{Multiclass{3}}(setosa=>0.283, versicolor=>0.51, virginica=>0.207)