Skip to content

Instantly share code, notes, and snippets.

@charecktowa
Created March 16, 2023 03:23
Show Gist options
  • Select an option

  • Save charecktowa/4b337dd14d8a317ec8dd337ff34cac14 to your computer and use it in GitHub Desktop.

Select an option

Save charecktowa/4b337dd14d8a317ec8dd337ff34cac14 to your computer and use it in GitHub Desktop.
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
public class KNearestNeighbor {
private int k;
private double[][] datos;
private int[] etiquetas;
public KNearestNeighbor(int k) {
this.k = k;
}
public KNearestNeighbor(double[][] X, int[] y, int k) {
this.datos = X;
this.etiquetas = y;
this.k = k;
}
/**
* El siguiente métdo se debería usar en caso de que se
* use el primer constructor, en el que el clasificador no se
* inicializa con los valores de entrenamiento.
*
* Este método solo lo creo para tener cierto parecido
* con "Scikit Learn" que es una librería de aprendizaje
* máquina de Python.
*
* @param X valores (caracteristicas)
* @param y etiquetas
*/
public void fit(double[][] X, int[] y) {
this.datos = X;
this.etiquetas = y;
}
/**
* Método que se encanrga de hacer el proceso del KNN
*
* @param X datos (caracteristicas) a clasificar / predecir
* @return un arreglo de valores enteros que representa las etiquetas de clase
*/
public int[] predecir(double[][] X) {
/*
* Primero calculamos las distancias de nuestro por dato
*
* En el KNN se sigue el siguiente proceso:
* 1. Calcular las distancias del nuevo dato respecto a los datos conocidos
* 2. Se ordenan las distancias de menor a mayor
* 3. Se toman los datos de las k distancias más cercanas
* 4. Se hace una "votación" para ver cuál es la etiqueta que más se repite
* 5. Se regresa la predicción
*/
int[] etiquetasPredecidas = new int[X.length];
for (int i = 0; i < X.length; i++) {
var distancias = new HashMap<Double, Integer>();
/* Calculamos distancias del punto actual respecto a todo el conjunto */
for (int j = 0; j < datos.length; j++) {
double distancia = distanciaEuclidea(X[i], datos[j]);
distancias.put(distancia, etiquetas[j]);
}
/*
* Ahora se ordenan las distancias de menor a mayor
* ThreeMap ayuda a ordenar de menor a mayor un HashMap.
* Se pueden usar otras técnicas para ordenar pero creo sería más
* complejo, aparte es sobre el KNN.
*/
var distanciasOrdenadas = new TreeMap<Double, Integer>(distancias);
/* Después únicamente tomamos las k distancias más cercanas */
var kDistanciasCercanas = obtenerKDistancias(distanciasOrdenadas);
etiquetasPredecidas[i] = masRepetido(kDistanciasCercanas);
}
return etiquetasPredecidas;
}
private double distanciaEuclidea(double[] x, double[] y) {
double distancia = 0.0;
for (int i = 0; i < x.length; i++) {
distancia += Math.pow((x[i] - y[i]), 2);
}
return Math.sqrt(distancia);
}
private int masRepetido(TreeMap<Double, Integer> elemento) {
var etiquetass = obtenerValores(elemento);
return etiquetass.stream()
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
.entrySet().stream().max((o1, o2) -> o1.getValue().compareTo(o2.getValue()))
.map(Map.Entry::getKey).orElse(null);
}
/**
* @param elemento
* @return
*/
private List<Integer> obtenerValores(TreeMap<Double, Integer> elemento) {
var etiquetass = new ArrayList<Integer>();
for (Map.Entry<Double, Integer> e : elemento.entrySet()) {
etiquetass.add(e.getValue());
}
return etiquetass;
}
private TreeMap<Double, Integer> obtenerKDistancias(SortedMap<Double, Integer> distancias) {
int n = 0;
var dist = new TreeMap<Double, Integer>();
for (Map.Entry<Double, Integer> e : distancias.entrySet()) {
if (n >= k)
break;
dist.put(e.getKey(), e.getValue());
n++;
}
return dist;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment