Lab 9 - Support Vector Machine
To ensure the code in this tutorial runs as shown, download the tutorial project folder and follow the instructions there. If you have questions or suggestions about this tutorial, please open an issue on the repository.
using MLJ
import RDatasets: dataset
using PrettyPrinting
using Random
We start by generating a 2D cloud of points:
Random.seed!(3203)
X = randn(20, 2)
y = vcat(-ones(10), ones(10))
20-element Vector{Float64}:
 -1.0
 -1.0
 -1.0
  ⋮
  1.0
  1.0
  1.0
which we can visualise:
using Plots
ym1 = y .== -1
ym2 = .!ym1
scatter(X[ym1, 1], X[ym1, 2], markershape=:circle, label="y=-1")
scatter!(X[ym2, 1], X[ym2, 2], markershape=:cross, label="y=1")
plot!(legend=:bottomright, xlabel="X1", ylabel="X2", title="Scatter Plot", size=(800,600))
savefig(joinpath(@OUTPUT, "ISL-lab-9-g1.svg")); # @OUTPUT is defined by the Franklin.jl site build; replace with your own output directory if running locally
Let's wrap the data as a table:
X = MLJ.table(X)
y = categorical(y);
and fit an SVM classifier:
SVC = @load SVC pkg=LIBSVM
import MLJLIBSVMInterface ✔
svc_mdl = SVC()
svc = machine(svc_mdl, X, y)
fit!(svc);
As usual, we can check how it performs:
ypred = MLJ.predict(svc, X)
misclassification_rate(ypred, y)
0.3
Not bad.
As usual, we could tune the model, for instance the cost penalty encoding the trade-off between margin width and misclassification:
rc = range(svc_mdl, :cost, lower=0.1, upper=5)
tm = TunedModel(model=svc_mdl, ranges=[rc], tuning=Grid(resolution=10),
resampling=CV(nfolds=3, rng=33), measure=misclassification_rate)
mtm = machine(tm, X, y)
fit!(mtm)
ypred = MLJ.predict(mtm, X)
misclassification_rate(ypred, y)
0.2
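To see which cost value the grid search actually selected, we can inspect the fitted tuned machine. Here is a short sketch, assuming the standard MLJ accessors `fitted_params` and `report` for a `TunedModel` machine:

```julia
# Inspect the outcome of the grid search (continues from the fitted `mtm` above).
best = fitted_params(mtm).best_model      # the SVC with the winning hyperparameters
@show best.cost                           # selected value of the cost penalty

r = report(mtm)
@show r.best_history_entry.measurement    # CV estimate of the misclassification rate
```

The cross-validated estimate reported here is generally a more honest measure of performance than the in-sample misclassification rate computed above.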
You could also change the kernel, among other hyperparameters.
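For instance, switching kernels is a one-line change. A sketch, assuming the `kernel` hyperparameter exposed by MLJLIBSVMInterface, which takes values from the `LIBSVM.Kernel` enumeration (the default is `LIBSVM.Kernel.RadialBasis`):

```julia
import LIBSVM  # provides the Kernel enumeration

# Refit with a polynomial kernel instead of the default radial basis one
# (continues from the `X`, `y` and `SVC` defined above).
svc_poly = machine(SVC(kernel=LIBSVM.Kernel.Polynomial), X, y)
fit!(svc_poly)
ypred_poly = MLJ.predict(svc_poly, X)
misclassification_rate(ypred_poly, y)
```

The same tuning strategy as above could then be applied to the kernel-specific hyperparameters (e.g. `gamma` or `degree`).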
© Thibaut Lienart, Anthony Blaom, Sebastian Vollmer and collaborators. Last modified: June 24, 2024. Website built with Franklin.jl.