MultitargetNeuralNetworkRegressor
mutable struct MultitargetNeuralNetworkRegressor <: MLJModelInterface.Deterministic
A simple but flexible feedforward neural network from the Beta Machine Learning Toolkit (BetaML), for regression of multi-dimensional targets.
Parameters:
- layers: Array of layer objects [def: nothing, i.e. basic network]. See subtypes(BetaML.AbstractLayer) for supported layers.
- loss: Loss (cost) function [def: BetaML.squared_cost]. It should always assume y and ŷ as matrices. Warning: if you change the parameter loss, you need to either provide its derivative in the parameter dloss or use autodiff with dloss=nothing (see the sketch after this list).
- dloss: Derivative of the loss function [def: BetaML.dsquared_cost, i.e. use the derivative of the squared cost]. Use nothing for autodiff.
- epochs: Number of epochs, i.e. passages through the whole training sample [def: 300].
- batch_size: Size of each individual batch [def: 16].
- opt_alg: The optimisation algorithm used to update the gradient at each batch [def: BetaML.ADAM()]. See subtypes(BetaML.OptimisationAlgorithm) for supported optimisers.
- shuffle: Whether to randomly shuffle the data at each iteration (epoch) [def: true].
- descr: An optional title and/or description for this model.
- cb: A callback function to provide information during training [def: BetaML.fitting_info].
- rng: Random Number Generator (see FIXEDSEED) [default: Random.GLOBAL_RNG].
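A minimal sketch of the custom-loss mechanism described above, assuming the model type has been loaded as modelType (as in the example below); mae_loss is a hypothetical mean-absolute-error loss written for this illustration, not part of BetaML:
julia> mae_loss(y, ŷ) = sum(abs.(y .- ŷ)) / size(y, 1);   # the loss must accept y and ŷ as matrices
julia> model = modelType(loss=mae_loss, dloss=nothing);    # dloss=nothing makes BetaML fall back to autodiff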
Notes:
- data must be numerical
- the label should be an n-records by n-dimensions matrix
Example:
julia> using MLJ
julia> X, y = @load_boston;
julia> ydouble = hcat(y, y .*2 .+5);
julia> modelType = @load MultitargetNeuralNetworkRegressor pkg="BetaML" verbosity=0
BetaML.Nn.MultitargetNeuralNetworkRegressor
julia> layers = [BetaML.DenseLayer(12,50,f=BetaML.relu), BetaML.DenseLayer(50,50,f=BetaML.relu), BetaML.DenseLayer(50,50,f=BetaML.relu), BetaML.DenseLayer(50,2,f=BetaML.relu)];
julia> model = modelType(layers=layers, opt_alg=BetaML.ADAM(), epochs=500)
MultitargetNeuralNetworkRegressor(
layers = BetaML.Nn.AbstractLayer[BetaML.Nn.DenseLayer([-0.2591582523441157 -0.027962845131416225 … ; … ], [0.14249427336553644, … ], BetaML.Utils.relu, BetaML.Utils.drelu), … ],
loss = BetaML.Utils.squared_cost,
dloss = BetaML.Utils.dsquared_cost,
epochs = 500,
batch_size = 16,
opt_alg = BetaML.Nn.ADAM(BetaML.Nn.var"#90#93"(), 1.0, 0.9, 0.999, 1.0e-8, BetaML.Nn.Learnable[], BetaML.Nn.Learnable[]),
shuffle = true,
descr = "",
cb = BetaML.Nn.fitting_info,
rng = Random._GLOBAL_RNG())
julia> mach = machine(model, X, ydouble);
julia> fit!(mach);
julia> ŷdouble = predict(mach, X);
julia> hcat(ydouble,ŷdouble)
506×4 Matrix{Float64}:
24.0 53.0 28.4624 62.8607
21.6 48.2 22.665 49.7401
34.7 74.4 31.5602 67.9433
33.4 71.8 33.0869 72.4337
⋮
23.9 52.8 23.3573 50.654
22.0 49.0 22.1141 48.5926
11.9 28.8 19.9639 45.5823
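As a quick follow-up check (not part of the original output), the per-target root mean squared error can be computed directly from the two matrices above; the exact values depend on the random weight initialisation:
julia> rmse = sqrt.(sum((ŷdouble .- ydouble).^2, dims=1) ./ size(ydouble, 1));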