KNNClassifier

A model type for constructing a K-nearest neighbor classifier, based on NearestNeighborModels.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels

Do model = KNNClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in KNNClassifier(K=...).
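
For example, assuming the model type has been loaded as above:

model = KNNClassifier()      ## default hyper-parameters (K = 5)
model = KNNClassifier(K = 3) ## override the number of neighbors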

KNNClassifier implements the K-nearest neighbors classifier, a non-parametric algorithm that predicts a discrete class distribution associated with a new point by taking a vote over the classes of the k nearest training points. Each neighbor's vote is assigned a weight based on the proximity of that neighbor to the test point, according to a specified distance metric.
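
For intuition, the voting scheme can be sketched as below. This is an illustrative re-implementation, not the code used by NearestNeighborModels.jl; the knn_vote name and its kernel keyword are hypothetical, with the default uniform kernel (d -> 1.0) recovering plain majority voting.

using Distances

## Illustrative weighted k-NN vote; observations are the columns of Xtrain.
function knn_vote(Xtrain, ytrain, xnew; K=5, metric=Euclidean(), kernel=d -> 1.0)
    dists = [metric(view(Xtrain, :, j), xnew) for j in axes(Xtrain, 2)]
    nearest = partialsortperm(dists, 1:K)          ## indices of the K closest points
    scores = Dict{eltype(ytrain),Float64}()
    for j in nearest
        scores[ytrain[j]] = get(scores, ytrain[j], 0.0) + kernel(dists[j])
    end
    total = sum(values(scores))
    return Dict(c => s/total for (c, s) in scores) ## class => predicted probability
end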

For more information about the weighting kernels, see the paper by Geler et al., "Comparison of different weighting schemes for the kNN classifier on time-series data".

Training data

In MLJ or MLJBase, bind an instance model to data with

mach = machine(model, X, y)

OR

mach = machine(model, X, y, w)

Here:

  • X is any table of input features (eg, a DataFrame) whose columns are of scitype Continuous; check column scitypes with schema(X).
  • y is the target, which can be any AbstractVector whose element scitype is <:Finite (<:Multiclass or <:OrderedFactor will do); check the scitype with scitype(y).
  • w is the observation weights, which can be either nothing (default) or an AbstractVector whose element scitype is Count or Continuous. These weights are not to be confused with the weights kernel, which is a model hyper-parameter; see below.

Train the machine using fit!(mach, rows=...).
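
A minimal sketch, assuming X and y as described above; the weight vector w here is a hypothetical example of Continuous observation weights:

w = rand(nrows(X))     ## one positive weight per observation (optional)
mach = machine(model, X, y, w)
fit!(mach, rows=1:100) ## train on the first 100 rows only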

Hyper-parameters

  • K::Int=5 : number of neighbors
  • algorithm::Symbol = :kdtree : one of (:kdtree, :brutetree, :balltree)
  • metric::Metric = Euclidean() : any Metric from Distances.jl for the distance between points. For algorithm = :kdtree only metrics which are instances of Distances.UnionMinkowskiMetric are supported.
  • leafsize::Int = 10 : determines the number of points at which to stop splitting the tree. This option is ignored and always taken as 0 for algorithm = :brutetree, since brutetree isn't actually a tree.
  • reorder::Bool = true : if true then points which are close in distance are placed close in memory. In this case, a copy of the original data will be made so that the original data is left unmodified. Setting this to true can significantly improve performance of the specified algorithm (except :brutetree). This option is ignored and always taken as false for algorithm = :brutetree.
  • weights::KNNKernel=Uniform() : kernel used in assigning weights to the k-nearest neighbors for each observation. An instance of one of the types in list_kernels(). User-defined weighting functions can be passed by wrapping the function in a UserDefinedKernel kernel (do ?NearestNeighborModels.UserDefinedKernel for more info). If observation weights w are passed during machine construction then the weight assigned to each neighbor vote is the product of the kernel generated weight for that neighbor and the corresponding observation weight.
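
By way of illustration, several of these hyper-parameters can be combined in a single instance; the particular metric and kernel chosen below are arbitrary:

import NearestNeighborModels
import Distances

model = KNNClassifier(
    K = 10,
    algorithm = :balltree,                     ## BallTree supports general metrics
    metric = Distances.Cityblock(),            ## L1 distance
    weights = NearestNeighborModels.Inverse(), ## closer neighbors vote more heavily
)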

Operations

  • predict(mach, Xnew): Return predictions of the target given features Xnew, which should have the same scitype as X above. Predictions are probabilistic but uncalibrated.
  • predict_mode(mach, Xnew): Return the modes of the probabilistic predictions returned above.
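
For instance, with mach trained as above and Xnew a table with the same scitype as X (the label "B" below is a placeholder for one of the levels of y):

y_hat = predict(mach, Xnew) ## vector of UnivariateFinite distributions
pdf.(y_hat, "B")            ## probability assigned to class "B" for each row
predict_mode(mach, Xnew)    ## most probable class for each row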

Fitted parameters

The fields of fitted_params(mach) are:

  • tree: An instance of either KDTree, BruteTree or BallTree, depending on the value of the algorithm hyper-parameter (see the Hyper-parameters section above). These data structures store the training data in a form designed to speed up nearest neighbor searches on test data points.
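
For example:

fp = fitted_params(mach)
fp.tree ## e.g. a KDTree when algorithm = :kdtree (the default)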

Examples

using MLJ
KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels
X, y = @load_crabs; ## a table and a vector from the crabs dataset
## view possible kernels
NearestNeighborModels.list_kernels()
## KNNClassifier instantiation
model = KNNClassifier(weights = NearestNeighborModels.Inverse())
mach = machine(model, X, y) |> fit! ## wrap model and required data in an MLJ machine and fit
y_hat = predict(mach, X) ## probabilistic predictions
labels = predict_mode(mach, X) ## point predictions (modes)
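
As a follow-up, the model can also be evaluated without constructing a machine by hand; the resampling strategy and measure below are illustrative choices:

evaluate(model, X, y,
         resampling=CV(nfolds=5, shuffle=true, rng=123),
         measure=log_loss)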

See also MultitargetKNNClassifier