MultitargetKNNClassifier
A model type for constructing a multitarget K-nearest neighbor classifier, based on NearestNeighborModels.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
```julia
MultitargetKNNClassifier = @load MultitargetKNNClassifier pkg=NearestNeighborModels
```

Do `model = MultitargetKNNClassifier()` to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in `MultitargetKNNClassifier(K=...)`.
Multi-target K-Nearest Neighbors Classifier (`MultitargetKNNClassifier`) is a variation of `KNNClassifier` that assumes the target variable is vector-valued with `Multiclass` or `OrderedFactor` components. (Target data must be presented as a table, however.)
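For example, a two-column target with one `Multiclass` and one `OrderedFactor` column might be prepared as in the following minimal sketch (the column names and values here are made up for illustration):

```julia
using MLJ

raw = (species = ["cat", "dog", "cat"], size = ["small", "large", "small"])
y = coerce(raw, :species => Multiclass, :size => OrderedFactor)
schema(y)   ## check the resulting column scitypes
```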
Training data
In MLJ or MLJBase, bind an instance `model` to data with

```julia
mach = machine(model, X, y)
```

OR

```julia
mach = machine(model, X, y, w)
```

Here:
- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype `Continuous`; check column scitypes with `schema(X)`.
- `y` is the target, which can be any table of responses whose element scitype is `<:Finite` (`<:Multiclass` or `<:OrderedFactor` will do); check the column scitypes with `schema(y)`. Each column of `y` is assumed to belong to a common categorical pool.
- `w` is the observation weights, which can be either `nothing` (default) or an `AbstractVector` whose element scitype is `Count` or `Continuous`. This is different from the `weights` kernel, which is a model hyper-parameter (see below); a sketch of binding with weights follows the next paragraph.
Train the machine using `fit!(mach, rows=...)`.
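A minimal sketch of binding with observation weights (the toy data and dimensions are made up for illustration):

```julia
using MLJ
MultitargetKNNClassifier = @load MultitargetKNNClassifier pkg=NearestNeighborModels

X = table(rand(8, 2))                            ## Continuous feature table
y = (class = categorical(rand(["a", "b"], 8)),)  ## one-column categorical target table
w = rand(8)                                      ## Continuous observation weights
mach = machine(MultitargetKNNClassifier(), X, y, w)
fit!(mach)
```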
Hyper-parameters
- `K::Int=5`: number of neighbors
- `algorithm::Symbol = :kdtree`: one of `(:kdtree, :brutetree, :balltree)`
- `metric::Metric = Euclidean()`: any `Metric` from Distances.jl for the distance between points. For `algorithm = :kdtree` only metrics which are instances of `Distances.UnionMinkowskiMetric` are supported.
- `leafsize::Int = 10`: determines the number of points at which to stop splitting the tree. This option is ignored and always taken as `0` for `algorithm = :brutetree`, since `brutetree` isn't actually a tree.
- `reorder::Bool = true`: if `true` then points which are close in distance are placed close in memory. In this case, a copy of the original data will be made so that the original data is left unmodified. Setting this to `true` can significantly improve performance of the specified `algorithm` (except `:brutetree`). This option is ignored and always taken as `false` for `algorithm = :brutetree`.
- `weights::KNNKernel=Uniform()`: kernel used in assigning weights to the k-nearest neighbors for each observation. An instance of one of the types in `list_kernels()`. User-defined weighting functions can be passed by wrapping the function in a `UserDefinedKernel` kernel (do `?NearestNeighborModels.UserDefinedKernel` for more info). If observation weights `w` are passed during machine construction, then the weight assigned to each neighbor vote is the product of the kernel-generated weight for that neighbor and the corresponding observation weight.
- `output_type::Type{<:MultiUnivariateFinite}=DictTable`: one of `(ColumnTable, DictTable)`. The type of table to use for predictions. Setting to `ColumnTable` might improve performance for narrow tables, while setting to `DictTable` improves performance for wide tables.
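For instance, a non-default configuration might look like the sketch below (Distances.jl is a dependency of NearestNeighborModels.jl, but must be in your environment for the `import` to work):

```julia
using MLJ
import NearestNeighborModels, Distances

MultitargetKNNClassifier = @load MultitargetKNNClassifier pkg=NearestNeighborModels

## brute-force search with a Manhattan metric and inverse-distance vote weights
model = MultitargetKNNClassifier(
    K = 7,
    algorithm = :brutetree,
    metric = Distances.Cityblock(),
    weights = NearestNeighborModels.Inverse(),
)
```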
Operations
- `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which should have the same scitype as `X` above. Predictions are either a `ColumnTable` or `DictTable` of `UnivariateFiniteVector` columns, depending on the value set for the `output_type` parameter discussed above. The probabilistic predictions are uncalibrated.
- `predict_mode(mach, Xnew)`: Return the modes of each column of the table of probabilistic predictions returned above.
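The container type of the probabilistic predictions is controlled by `output_type`. A hedged sketch, re-using names from the Examples section below, and assuming `ColumnTable` is accessible as `NearestNeighborModels.ColumnTable`:

```julia
## default: predictions are a DictTable
y_hat = predict(mach, X)

## switch to column-table output and retrain
model.output_type = NearestNeighborModels.ColumnTable
fit!(mach)
y_hat = predict(mach, X)   ## now a NamedTuple of UnivariateFinite vectors
y_hat.fruit                ## probabilistic predictions for the :fruit target
```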
Fitted parameters
The fields of `fitted_params(mach)` are:
- `tree`: An instance of either `KDTree`, `BruteTree` or `BallTree`, depending on the value of the `algorithm` hyper-parameter (see the hyper-parameters section above). These are data structures that store the training data with a view to making quicker nearest neighbor searches on test data points.
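A quick sketch of inspecting the fitted search structure (assuming `mach` has been trained, as in the Examples below):

```julia
fp = fitted_params(mach)
fp.tree   ## e.g. a KDTree, for the default algorithm = :kdtree
```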
Examples
```julia
using MLJ, StableRNGs
## set rng for reproducibility
rng = StableRNG(10)
## Dataset generation
n, p = 10, 3
X = table(randn(rng, n, p)) ## feature table
fruit, color = categorical(["apple", "orange"]), categorical(["blue", "green"])
y = [(fruit = rand(rng, fruit), color = rand(rng, color)) for _ in 1:n] ## target table
## Each column in y has a common categorical pool as expected
selectcols(y, :fruit) ## categorical array
selectcols(y, :color) ## categorical array
## Load MultitargetKNNClassifier
MultitargetKNNClassifier = @load MultitargetKNNClassifier pkg=NearestNeighborModels
## view possible kernels
NearestNeighborModels.list_kernels()
## MultitargetKNNClassifier instantiation
model = MultitargetKNNClassifier(K=3, weights = NearestNeighborModels.Inverse())
## wrap model and required data in an MLJ machine and fit
mach = machine(model, X, y) |> fit!
## predict
y_hat = predict(mach, X)
labels = predict_mode(mach, X)
```
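Continuing the example, class probabilities can be extracted from the probabilistic predictions by broadcasting `pdf`. A hedged sketch, assuming Tables.jl is installed (`y_hat` is a table of `UnivariateFinite` columns under the default `output_type = DictTable`):

```julia
import Tables

fruit_probs = Tables.getcolumn(Tables.columns(y_hat), :fruit)
pdf.(fruit_probs, "apple")   ## probability assigned to "apple" for each observation
```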
See also `KNNClassifier`.