Nearest Neighbors

Written by Zhiheng Zhang on 11/10/2022.

Introduction

torchml.neighbors currently supports unsupervised nearest-neighbor search and K Nearest Neighbors classification. torchml.neighbors.NearestNeighbors reimplements the brute-force solution of sklearn.neighbors.NearestNeighbors in PyTorch. Classification is available through torchml.neighbors.KNeighborsClassifier.

Probabilistic Derivation

K Nearest Neighbors classification

The principle behind Nearest Neighbors algorithms works as follows. Given a distance function $d$ and a new test point $x$, the algorithm finds the $k$ closest samples in the known sample set and uses them to estimate the label $y$ of $x$. The number $k$ is user-defined and can be tuned for the particular problem. The distance function can be any metric, and the standard Euclidean distance is the most common choice.
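The role of the distance function can be sketched in plain Python. This is an illustration, not torchml code. The helpers `euclidean`, `manhattan` and `k_nearest_indices` are hypothetical names. The search is brute force: it sorts every sample by its distance to the query and keeps the first $k$.

```python
import math
from typing import Callable, List, Sequence

Point = Sequence[float]

def euclidean(a: Point, b: Point) -> float:
    # standard L2 distance
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def manhattan(a: Point, b: Point) -> float:
    # L1 distance: an alternative metric
    return sum(abs(ai - bi) for ai, bi in zip(a, b))

def k_nearest_indices(x: Point, samples: Sequence[Point], k: int,
                      metric: Callable[[Point, Point], float] = euclidean) -> List[int]:
    """Hypothetical helper: indices of the k samples closest to x under `metric` (brute force)."""
    return sorted(range(len(samples)), key=lambda i: metric(x, samples[i]))[:k]

samples = [(1.5, 0.0), (1.0, 1.0), (5.0, 5.0)]
x = (0.0, 0.0)
print(k_nearest_indices(x, samples, k=1))                    # [1]
print(k_nearest_indices(x, samples, k=1, metric=manhattan))  # [0]
```

Under the Euclidean metric, sample 1 at (1, 1) is closest to the origin, because $\sqrt{2} \approx 1.41 < 1.5$. Under the Manhattan metric, sample 0 at (1.5, 0) is closest, because $1.5 < 2$. The choice of metric can therefore change which neighbors are found.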

One important thing about this algorithm is that it is not based on any probabilistic framework. Even so, it can estimate the probability of each class given a test point and its $k$ neighbors.

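Consider a dataset of $N$ labeled samples $\{(x_i, y_i)\}_{i=1}^{N}$ with $C$ distinct classes, and a new point $x$ that we wish to classify.

First, we count the number of samples that fall into each class $c$:

$$N_c = \sum_{i=1}^{N} \mathbb{1}[y_i = c], \qquad \sum_{c=1}^{C} N_c = N$$

Next, we find the $k$ nearest neighbors of $x$:

$$\mathcal{N}_k(x) = \{\, i_1, \dots, i_k \,\}, \qquad d(x, x_{i_1}) \le \dots \le d(x, x_{i_k}) \le d(x, x_j) \;\; \forall j \notin \mathcal{N}_k(x)$$

We then count the number of neighbors that belong to class $c$:

$$k_c = \sum_{i \in \mathcal{N}_k(x)} \mathbb{1}[y_i = c], \qquad \sum_{c=1}^{C} k_c = k$$

The probability that $x$ is of class $c$ is simply:

$$p(y = c \mid x) = \frac{k_c}{k}$$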

This estimation is often accurate in practice, even though the algorithm is not built with probability in mind.
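Here is a worked example of the counting step. It uses the toy data from the torchml interface example: samples 0, 1, 2, 3 with labels 0, 0, 1, 1, query point 1.1 and $k = 3$. The three nearest samples are 1, 2 and 0, at distances 0.1, 0.9 and 1.1. Their labels are 0, 1 and 0, so $k_0 = 2$ and $k_1 = 1$. The hypothetical helper below turns neighbor labels into the estimates $k_c / k$:

```python
from collections import Counter
from typing import List, Sequence

def class_probabilities(neighbor_labels: Sequence[int], n_classes: int) -> List[float]:
    """Hypothetical helper: estimate p(y = c | x) = k_c / k from the labels of the k nearest neighbors."""
    k = len(neighbor_labels)
    counts = Counter(neighbor_labels)                 # k_c for every class that appears
    return [counts[c] / k for c in range(n_classes)]  # Counter yields 0 for absent classes

# labels of the 3 nearest neighbors of x = 1.1 (samples 1, 2 and 0)
print(class_probabilities([0, 1, 0], n_classes=2))  # [0.6666666666666666, 0.3333333333333333]
```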

KNN from a Bayesian standpoint

The KNN algorithm is not built on a probabilistic framework. Still, framing it in Bayesian terms gives intuition for why its estimates are surprisingly good.

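What we want is the posterior probability of class $c$ given $x$:

$$p(c \mid x)$$

In Bayesian terms, by Bayes' theorem, this requires:

$$p(c \mid x) = \frac{p(x \mid c)\, p(c)}{p(x)}$$

Given nothing but our samples, the prior $p(c)$ is simply the fraction of samples in class $c$. Here $N_c$ is the number of samples labeled $c$ out of $N$ in total:

$$p(c) = \frac{N_c}{N}$$

$p(x)$ is the probability density of the random variable $x$. For this part of the analysis, we need to borrow some results from density estimation.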

Since we don't know $p(x)$, we estimate it from discrete trials. Suppose the density lives in a $D$-dimensional space, which we assume to be Euclidean. We conduct $N$ trials in this space by drawing points according to $p(x)$; these points are our samples. By the principle of locality, we can assume that the density at a given point $x$ is closely related to the density in a small region around it. Let's draw a small sphere around $x$ and call the region inside the sphere $\mathcal{R}$.

The probability that a point drawn from $p(x)$ lands inside $\mathcal{R}$ is the integral of the density over all points in $\mathcal{R}$. This is the probability mass of $p(x)$ in $\mathcal{R}$:

$$P = \int_{\mathcal{R}} p(x')\,dx'$$

Each of our $N$ samples independently falls inside $\mathcal{R}$ with probability $P$. The total number $K$ of samples that end up in $\mathcal{R}$ therefore follows a binomial distribution:

$$\mathrm{Bin}(K \mid N, P) = \binom{N}{K} P^{K} (1 - P)^{N - K}$$

We also have:

$$\mathbb{E}\left[\frac{K}{N}\right] = P, \qquad \mathrm{var}\left[\frac{K}{N}\right] = \frac{P(1 - P)}{N}$$

For large $N$, the fraction $K/N$ is therefore sharply concentrated around its mean, and $K \approx NP$.

In our algorithm we supply the parameter $k$, so we can substitute our chosen $k$ for the expected count. This gives:

$$P \approx \frac{k}{N}$$
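The concentration of $K/N$ around $P$ is easy to check numerically. The Monte Carlo sketch below is plain Python, and `fraction_in_region` is a hypothetical helper. It treats each sample as an independent trial that lands in $\mathcal{R}$ with probability $P$. It then compares the observed fraction with the standard deviation $\sqrt{P(1-P)/N}$:

```python
import math
import random

def fraction_in_region(n_samples: int, p_region: float, rng: random.Random) -> float:
    """Hypothetical helper: draw n_samples Bernoulli(p_region) trials and return K / N."""
    k = sum(1 for _ in range(n_samples) if rng.random() < p_region)
    return k / n_samples

rng = random.Random(0)
P = 0.3
for n in (10, 100, 10_000):
    frac = fraction_in_region(n, P, rng)
    std = math.sqrt(P * (1 - P) / n)  # theoretical std of K / N
    print(f"N={n:>6}: K/N = {frac:.3f}, std of K/N = {std:.4f}")
```

As $N$ grows, the observed $K/N$ settles around $P$ and the spread shrinks like $1/\sqrt{N}$. This is what justifies replacing the expectation with a single count.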

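We further assume that $\mathcal{R}$ is small enough that $p(x)$ changes very little inside it, so $p(x)$ is approximately uniform over $\mathcal{R}$. Then we can derive that:

$$P \approx p(x)\,V$$

where $V$ is the volume of $\mathcal{R}$. In KNN, $k$ is fixed and the sphere grows until it contains exactly $k$ samples. As a result, $V$ is set by the distance from $x$ to its $k$-th nearest neighbor.

Then our final estimate of $p(x)$ is:

$$p(x) \approx \frac{k}{NV}$$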
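The estimate $p(x) \approx k/(NV)$ can be sketched directly in Python. The sketch assumes Euclidean distance, so $\mathcal{R}$ is a $D$-dimensional ball with volume $\pi^{D/2} r^D / \Gamma(D/2 + 1)$. The helpers `ball_volume` and `knn_density` are hypothetical. The test data are uniform on the unit square, where the true density is 1:

```python
import math
import random

def ball_volume(radius: float, dim: int) -> float:
    """Volume of a D-dimensional Euclidean ball: pi^(D/2) / Gamma(D/2 + 1) * r^D."""
    return math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * radius ** dim

def knn_density(x, samples, k: int) -> float:
    """Hypothetical helper: p(x) ~= k / (N V), where V is the volume of the
    smallest ball around x that contains k samples."""
    dists = sorted(math.dist(x, s) for s in samples)
    r_k = dists[k - 1]  # distance to the k-th nearest sample
    return k / (len(samples) * ball_volume(r_k, len(x)))

rng = random.Random(0)
samples = [(rng.random(), rng.random()) for _ in range(2000)]  # uniform on [0, 1)^2
print(knn_density((0.5, 0.5), samples, k=50))  # should be close to the true density 1
```

A small $k$ makes the estimate noisy. A large $k$ grows the sphere and breaks the "small $\mathcal{R}$" assumption. This trade-off is why $k$ is worth tuning.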

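We repeat the process for a specific class $c$, using the same sphere of volume $V$. Of the $N_c$ samples in class $c$, $k_c$ fall inside the sphere, which gives:

$$p(x \mid c) \approx \frac{k_c}{N_c V}$$

Substituting $p(c) = N_c / N$, $p(x)$ and $p(x \mid c)$ into Bayes' theorem gives:

$$p(c \mid x) = \frac{p(x \mid c)\, p(c)}{p(x)} \approx \frac{\dfrac{k_c}{N_c V} \cdot \dfrac{N_c}{N}}{\dfrac{k}{N V}} = \frac{k_c}{k}$$

The volume $V$ and the counts $N$ and $N_c$ all cancel. What remains is exactly the fraction of the $k$ neighbors that belong to class $c$, which is the KNN estimate.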
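Here is a quick exact-arithmetic check of the cancellation, using Python's `fractions` module. The helper `knn_posterior_via_bayes` is hypothetical. The numbers come from the toy interface data ($N = 4$, $N_0 = 2$, $k = 3$, $k_0 = 2$), combined with arbitrary volumes:

```python
from fractions import Fraction

def knn_posterior_via_bayes(k_c, k, n_c, n, volume):
    """Hypothetical helper: p(c|x) = p(x|c) p(c) / p(x) with the k-NN density estimates plugged in."""
    prior = Fraction(n_c, n)                     # p(c)   = N_c / N
    likelihood = Fraction(k_c) / (n_c * volume)  # p(x|c) ~= k_c / (N_c V)
    evidence = Fraction(k) / (n * volume)        # p(x)   ~= k / (N V)
    return likelihood * prior / evidence

# The volume is arbitrary: it cancels, as do N and N_c.
for v in (Fraction(11, 5), Fraction(7)):
    print(knn_posterior_via_bayes(k_c=2, k=3, n_c=2, n=4, volume=v))  # 2/3
```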

Algorithmic Implementation

Given a new sample, the brute-force algorithm is:

1. Calculate the distances between the new sample and every labeled example.
2. Find the $k$ labeled samples with the smallest distances.
3. For each class, divide the number of neighbors in that class by $k$. That ratio is the estimated probability that the new sample belongs to the class.
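The three steps map naturally onto tensor operations. The sketch below is an illustration under stated assumptions, not torchml's actual implementation. `knn_predict_proba` is a hypothetical name. It assumes Euclidean distance, integer labels in `0..n_classes-1`, and 2-D `(n_samples, n_features)` float inputs. It uses `torch.cdist` for step 1 and `torch.topk` for step 2:

```python
import torch
import torch.nn.functional as F

def knn_predict_proba(train_x: torch.Tensor, train_y: torch.Tensor,
                      query: torch.Tensor, k: int, n_classes: int) -> torch.Tensor:
    """Hypothetical brute-force KNN: returns an (M, n_classes) tensor of k_c / k.

    train_x: (N, D) float, train_y: (N,) int64 labels, query: (M, D) float.
    """
    # Step 1: Euclidean distances between every query and every sample, shape (M, N)
    dists = torch.cdist(query, train_x)
    # Step 2: indices of the k smallest distances per query, shape (M, k)
    _, idx = torch.topk(dists, k, dim=1, largest=False)
    neighbor_labels = train_y[idx]
    # Step 3: count neighbors per class, shape (M, n_classes), then divide by k
    counts = F.one_hot(neighbor_labels, num_classes=n_classes).sum(dim=1)
    return counts.to(query.dtype) / k

train_x = torch.tensor([[0.0], [1.0], [2.0], [3.0]])
train_y = torch.tensor([0, 0, 1, 1])
query = torch.tensor([[1.1]])
proba = knn_predict_proba(train_x, train_y, query, k=3, n_classes=2)
print(proba)                # one row of class probabilities per query
print(proba.argmax(dim=1))  # most likely class per query
```

On the toy data from the interface example, this gives probability 2/3 for class 0 and 1/3 for class 1, so the predicted label is 0. Building the full distance matrix costs $O(MND)$ time and $O(MN)$ memory, which is the price of the brute-force approach.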

The torchml Interface

```python
import numpy as np
import torch
import torchml as ml

samples = np.array([[0.0], [1.0], [2.0], [3.0]])  # 4 samples, 1 feature each (float)
y = np.array([0, 0, 1, 1])                        # class label of each sample
point = np.array([[1.1]])                         # one query point, shape (1, n_features)

neigh = ml.neighbors.KNeighborsClassifier(n_neighbors=3)
neigh.fit(torch.from_numpy(samples), torch.from_numpy(y))
neigh.predict(torch.from_numpy(point))        # returns the most likely class label
neigh.predict_proba(torch.from_numpy(point))  # returns all the class probabilities
```

References