KMeans

A model type for constructing a K-means clusterer, based on Clustering.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

KMeans = @load KMeans pkg=Clustering

Do model = KMeans() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in KMeans(k=...).

K-means is a classical method for clustering or vector quantization. It produces a fixed number of clusters, each associated with a center (also known as a prototype), and assigns each data point to the cluster whose center is nearest.

From a mathematical standpoint, K-means is a coordinate descent algorithm that seeks to solve the following optimization problem (in general it is only guaranteed to reach a local minimum):

$$
\text{minimize} \ \sum_{i=1}^n \| \mathbf{x}_i - \boldsymbol{\mu}_{z_i} \|^2 \ \text{w.r.t.} \ (\boldsymbol{\mu}, z)
$$

Here, $\boldsymbol{\mu}_k$ is the center of the $k$-th cluster, and $z_i$ is the index of the cluster assigned to the $i$-th point $\mathbf{x}_i$.
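The coordinate-descent view can be made concrete with Lloyd's algorithm, which alternates between the two blocks of variables: with the centers $\boldsymbol{\mu}$ held fixed, the best $z$ assigns each point to its nearest center; with $z$ held fixed, the best $\boldsymbol{\mu}_k$ is the mean of the points in cluster $k$. The sketch below assumes squared Euclidean distance (the default metric) and d×n data with one point per column. `lloyd` and `objective` are hypothetical helpers for illustration, not Clustering.jl's API, and seeding is left to the caller, which is the role the `init` hyper-parameter plays.

```julia
using Statistics

# Objective from the formula above: sum of squared Euclidean distances from
# each point x_i (column i of X) to its assigned center μ_{z_i} (column z_i of μ).
objective(X, μ, z) = sum(sum(abs2, X[:, i] .- μ[:, z[i]]) for i in axes(X, 2))

# Hypothetical Lloyd-style coordinate descent (a sketch, not Clustering.jl's code).
# X is d×n (one point per column); μ0 is d×k (one initial center per column).
function lloyd(X::AbstractMatrix{<:Real}, μ0::AbstractMatrix{<:Real}; maxiter::Int=100)
    μ = Matrix{Float64}(μ0)
    n, k = size(X, 2), size(μ, 2)
    z = zeros(Int, n)
    for _ in 1:maxiter
        # Step 1: hold μ fixed and assign each point to its nearest center.
        znew = [argmin([sum(abs2, X[:, i] .- μ[:, j]) for j in 1:k]) for i in 1:n]
        # Step 2: hold z fixed and move each center to the mean of its points
        # (a center with no assigned points is left where it is).
        for j in 1:k
            members = findall(==(j), znew)
            isempty(members) || (μ[:, j] = vec(mean(X[:, members]; dims=2)))
        end
        # Stop once an assignment step changes nothing.
        znew == z && break
        z = znew
    end
    return μ, z
end

# Two unit squares, seeded with one point from each.
X = [0 0 1 1 10 10 11 11;
     0 1 0 1 10 11 10 11]
μ, z = lloyd(X, X[:, [1, 5]])
```

Each step can only lower (or keep) the objective, which is why the loop terminates; with poorly placed seeds it can still settle in a local minimum.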

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X)

Here:

  • X is any table of input features (e.g., a DataFrame) whose columns are of scitype Continuous; check column scitypes with schema(X).

Train the machine using fit!(mach, rows=...).

Hyper-parameters

  • k=3: The number of centroids to use in clustering.

  • metric::SemiMetric=Distances.SqEuclidean(): The metric used to measure distances between points and cluster centers. Must be a SemiMetric (a subtype of PreMetric) from Distances.jl.

  • init=:kmpp: One of the following options to indicate how cluster seeds should be initialized:

    • :kmpp: KMeans++
    • :kmcen: K-medoids initialization based on centrality
    • :rand: random
    • an instance of Clustering.SeedingAlgorithm from Clustering.jl
    • an integer vector of length k that provides the indices of points to use as initial cluster centers.

    See the Clustering.jl documentation for details.
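The :kmpp option refers to K-means++ seeding: the first seed is drawn uniformly, and each further seed is drawn with probability proportional to its squared distance to the nearest seed already chosen, which tends to spread seeds across the data. The sketch below illustrates that rule, assuming d×n data with one point per column; `kmpp_seeds` is a hypothetical helper, not Clustering.jl's implementation.

```julia
using Random

# Hypothetical helper sketching K-means++ seeding (not Clustering.jl's code).
# X is d×n with one point per column; returns k distinct column indices.
function kmpp_seeds(X::AbstractMatrix{<:Real}, k::Int; rng::AbstractRNG=Random.default_rng())
    n = size(X, 2)
    seeds = [rand(rng, 1:n)]  # first seed: uniform over all points
    # D2[i] = squared distance from point i to its nearest seed so far
    D2 = [float(sum(abs2, X[:, i] .- X[:, seeds[1]])) for i in 1:n]
    for _ in 2:k
        cs = cumsum(D2)
        cs[end] > 0 || error("fewer than k distinct points")
        # Draw the next seed with probability proportional to D2; points that
        # are already seeds have D2 == 0 and can never be drawn.
        r = rand(rng) * cs[end]
        next = something(findfirst(>(r), cs), findlast(>(0), D2))
        push!(seeds, next)
        # Update nearest-seed distances using the new seed.
        for i in 1:n
            D2[i] = min(D2[i], sum(abs2, X[:, i] .- X[:, next]))
        end
    end
    return seeds
end

X = [0 0 1 1 10 10 11 11;
     0 1 0 1 10 11 10 11]
seeds = kmpp_seeds(X, 2; rng=MersenneTwister(42))
```

Because the sketch returns point indices, its output has the same form as the integer-vector option for init listed above.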

Operations

  • predict(mach, Xnew): return cluster label assignments, given new features Xnew having the same scitype as X above.
  • transform(mach, Xnew): instead return the pairwise distances, computed with metric, from each new sample to each cluster center (one column per cluster).
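With the default SqEuclidean metric, the two operations are closely related: transform reports an n×k matrix of distances, and predict in effect picks, for each sample, the cluster at the smallest distance. The sketch below illustrates that relationship with plain matrices, assuming samples as rows (as in a table) and centers as columns; `sqdists` and `nearest` are hypothetical helpers, not MLJ or Clustering.jl functions.

```julia
# Hypothetical helper: entry (i, j) is the squared Euclidean distance from
# row i of Xnew (one sample per row) to column j of centers (one center per column).
sqdists(Xnew, centers) =
    [sum(abs2, Xnew[i, :] .- centers[:, j]) for i in axes(Xnew, 1), j in axes(centers, 2)]

# Nearest-center label for each sample: column index of the row minimum.
nearest(D) = [argmin(D[i, :]) for i in axes(D, 1)]

centers = [0.5 10.5;
           0.5 10.5]            # d = 2 features, k = 2 clusters
Xnew = [0.0 0.0;
        11.0 10.0]              # n = 2 new samples
D = sqdists(Xnew, centers)      # 2×2 distance matrix, one column per cluster
labels = nearest(D)
D0 = sqdists(permutedims(centers), centers)  # centers as samples: zero diagonal
```

Feeding the centers themselves back in gives zeros on the diagonal, which is the property the assertions in the Examples section check.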

Fitted parameters

The fields of fitted_params(mach) are:

  • centers: A matrix of cluster-center coordinates, with one row per feature and one column per cluster.

Report

The fields of report(mach) are:

  • assignments: The cluster assignments of each point in the training data.
  • cluster_labels: The labels assigned to each cluster.

Examples

using MLJ
KMeans = @load KMeans pkg=Clustering

table = load_iris()
y, X = unpack(table, ==(:target), rng=123)
model = KMeans(k=3)
mach = machine(model, X) |> fit!

yhat = predict(mach, X)
@assert yhat == report(mach).assignments

compare = zip(yhat, y) |> collect;
compare[1:8] ## clusters align with classes

center_dists = transform(mach, fitted_params(mach).centers')

@assert center_dists[1][1] == 0.0
@assert center_dists[2][2] == 0.0
@assert center_dists[3][3] == 0.0

See also KMedoids.