KMeans
A model type for constructing a K-means clusterer, based on Clustering.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
KMeans = @load KMeans pkg=Clustering
Do model = KMeans() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in KMeans(k=...).
K-means is a classical method for clustering or vector quantization. It produces a fixed number of clusters, each associated with a center (also known as a prototype), and each data point is assigned to the cluster with the nearest center.
From a mathematical standpoint, K-means is a coordinate descent algorithm that seeks a (generally local) solution of the following optimization problem:
$$
\text{minimize} \ \sum_{i=1}^n \| \mathbf{x}_i - \boldsymbol{\mu}_{z_i} \|^2 \ \text{w.r.t.} \ (\boldsymbol{\mu}, z)
$$
Here, $\boldsymbol{\mu}_k$ is the center of the $k$-th cluster, and $z_i$ is the index of the cluster assigned to the $i$-th point $\mathbf{x}_i$.
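The alternating structure of this problem can be illustrated with a minimal plain-Julia sketch. It assumes squared Euclidean distance and initial centers chosen by point index; the names `lloyd` and `objective` are hypothetical and are not part of Clustering.jl or MLJ, whose actual implementation is more elaborate. Each pass fixes $\boldsymbol{\mu}$ and updates $z$, then fixes $z$ and updates $\boldsymbol{\mu}$, so the objective never increases.

```julia
# Plain-Julia sketch of the alternating (coordinate descent) updates.
# `lloyd` and `objective` are hypothetical names, not Clustering.jl or MLJ API.

# Sum of squared distances from each point (a column of X) to its assigned center.
objective(X, centers, z) =
    sum(sum(abs2, X[:, i] .- centers[:, z[i]]) for i in axes(X, 2))

function lloyd(X::AbstractMatrix{<:Real}, init::AbstractVector{<:Integer}; maxiter=100)
    centers = float.(X[:, init])          # initial centers: the points indexed by `init`
    k, n = size(centers, 2), size(X, 2)
    z = zeros(Int, n)
    for _ in 1:maxiter
        # Step 1: with centers fixed, assign each point to its nearest center.
        znew = [argmin([sum(abs2, X[:, i] .- centers[:, j]) for j in 1:k]) for i in 1:n]
        znew == z && break                # assignments unchanged: stop
        z = znew
        # Step 2: with assignments fixed, move each center to the mean of its points.
        for j in 1:k
            members = findall(==(j), z)
            isempty(members) && continue  # leave an empty cluster's center unchanged
            centers[:, j] = vec(sum(X[:, members], dims=2)) ./ length(members)
        end
    end
    return centers, z
end

# Two well-separated groups of three points each (columns are observations).
X = [0.0 0.1 0.2 5.0 5.1 5.2;
     0.0 0.1 0.0 5.0 5.1 5.0]
centers, z = lloyd(X, [1, 4])
```

Note that this sketch stores observations as columns of a matrix, whereas the MLJ model described here takes a table whose rows are observations.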
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X)
Here:
- X is any table of input features (e.g., a DataFrame) whose columns are of scitype Continuous; check column scitypes with schema(X).
Train the machine using fit!(mach, rows=...).
Hyper-parameters
- k=3: The number of centroids to use in clustering.
- metric::SemiMetric=Distances.SqEuclidean: The metric used to calculate the clustering. Must have type PreMetric from Distances.jl.
- init = :kmpp: One of the following options to indicate how cluster seeds should be initialized:
  - :kmpp: KMeans++
  - :kmcen: K-medoids initialization based on centrality
  - :rand: random
  - an instance of Clustering.SeedingAlgorithm from Clustering.jl
  - an integer vector of length k that provides the indices of points to use as initial cluster centers.
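As a brief illustration of the init options listed above, the following sketch constructs three model instances. It assumes MLJ and the Clustering.jl interface package are installed in the active environment; the variable names and the particular indices 1, 51, 101 are arbitrary choices made for illustration only.

```julia
using MLJ

# Load the model type (assumes the interface package is available in the environment).
KMeans = @load KMeans pkg=Clustering

model_kmpp = KMeans(k=3)                      # default seeding: init = :kmpp
model_rand = KMeans(k=3, init=:rand)          # random seeding
model_idx  = KMeans(k=3, init=[1, 51, 101])   # points 1, 51 and 101 as initial centers (arbitrary picks)
```

When an integer vector is supplied, its length should match k, and each index should refer to a point in the training data.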
Operations
- predict(mach, Xnew): return cluster label assignments, given new features Xnew having the same scitype as X above.
- transform(mach, Xnew): instead return the mean pairwise distances from new samples to the cluster centers.
Fitted parameters
The fields of fitted_params(mach) are:
- centers: The coordinates of the cluster centers.
Report
The fields of report(mach) are:
- assignments: The cluster assignments of each point in the training data.
- cluster_labels: The labels assigned to each cluster.
Examples
using MLJ
KMeans = @load KMeans pkg=Clustering
table = load_iris()
y, X = unpack(table, ==(:target), rng=123)
model = KMeans(k=3)
mach = machine(model, X) |> fit!
yhat = predict(mach, X)
@assert yhat == report(mach).assignments
compare = zip(yhat, y) |> collect;
compare[1:8] ## clusters align with classes
center_dists = transform(mach, fitted_params(mach).centers')
@assert center_dists[1][1] == 0.0
@assert center_dists[2][2] == 0.0
@assert center_dists[3][3] == 0.0
See also KMedoids