Created
March 16, 2023 03:23
-
-
Save charecktowa/4b337dd14d8a317ec8dd337ff34cac14 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.util.ArrayList; | |
| import java.util.HashMap; | |
| import java.util.List; | |
| import java.util.Map; | |
| import java.util.SortedMap; | |
| import java.util.TreeMap; | |
| import java.util.function.Function; | |
| import java.util.stream.Collectors; | |
| public class KNearestNeighbor { | |
| private int k; | |
| private double[][] datos; | |
| private int[] etiquetas; | |
| public KNearestNeighbor(int k) { | |
| this.k = k; | |
| } | |
| public KNearestNeighbor(double[][] X, int[] y, int k) { | |
| this.datos = X; | |
| this.etiquetas = y; | |
| this.k = k; | |
| } | |
| /** | |
| * El siguiente métdo se debería usar en caso de que se | |
| * use el primer constructor, en el que el clasificador no se | |
| * inicializa con los valores de entrenamiento. | |
| * | |
| * Este método solo lo creo para tener cierto parecido | |
| * con "Scikit Learn" que es una librería de aprendizaje | |
| * máquina de Python. | |
| * | |
| * @param X valores (caracteristicas) | |
| * @param y etiquetas | |
| */ | |
| public void fit(double[][] X, int[] y) { | |
| this.datos = X; | |
| this.etiquetas = y; | |
| } | |
| /** | |
| * Método que se encanrga de hacer el proceso del KNN | |
| * | |
| * @param X datos (caracteristicas) a clasificar / predecir | |
| * @return un arreglo de valores enteros que representa las etiquetas de clase | |
| */ | |
| public int[] predecir(double[][] X) { | |
| /* | |
| * Primero calculamos las distancias de nuestro por dato | |
| * | |
| * En el KNN se sigue el siguiente proceso: | |
| * 1. Calcular las distancias del nuevo dato respecto a los datos conocidos | |
| * 2. Se ordenan las distancias de menor a mayor | |
| * 3. Se toman los datos de las k distancias más cercanas | |
| * 4. Se hace una "votación" para ver cuál es la etiqueta que más se repite | |
| * 5. Se regresa la predicción | |
| */ | |
| int[] etiquetasPredecidas = new int[X.length]; | |
| for (int i = 0; i < X.length; i++) { | |
| var distancias = new HashMap<Double, Integer>(); | |
| /* Calculamos distancias del punto actual respecto a todo el conjunto */ | |
| for (int j = 0; j < datos.length; j++) { | |
| double distancia = distanciaEuclidea(X[i], datos[j]); | |
| distancias.put(distancia, etiquetas[j]); | |
| } | |
| /* | |
| * Ahora se ordenan las distancias de menor a mayor | |
| * ThreeMap ayuda a ordenar de menor a mayor un HashMap. | |
| * Se pueden usar otras técnicas para ordenar pero creo sería más | |
| * complejo, aparte es sobre el KNN. | |
| */ | |
| var distanciasOrdenadas = new TreeMap<Double, Integer>(distancias); | |
| /* Después únicamente tomamos las k distancias más cercanas */ | |
| var kDistanciasCercanas = obtenerKDistancias(distanciasOrdenadas); | |
| etiquetasPredecidas[i] = masRepetido(kDistanciasCercanas); | |
| } | |
| return etiquetasPredecidas; | |
| } | |
| private double distanciaEuclidea(double[] x, double[] y) { | |
| double distancia = 0.0; | |
| for (int i = 0; i < x.length; i++) { | |
| distancia += Math.pow((x[i] - y[i]), 2); | |
| } | |
| return Math.sqrt(distancia); | |
| } | |
| private int masRepetido(TreeMap<Double, Integer> elemento) { | |
| var etiquetass = obtenerValores(elemento); | |
| return etiquetass.stream() | |
| .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) | |
| .entrySet().stream().max((o1, o2) -> o1.getValue().compareTo(o2.getValue())) | |
| .map(Map.Entry::getKey).orElse(null); | |
| } | |
| /** | |
| * @param elemento | |
| * @return | |
| */ | |
| private List<Integer> obtenerValores(TreeMap<Double, Integer> elemento) { | |
| var etiquetass = new ArrayList<Integer>(); | |
| for (Map.Entry<Double, Integer> e : elemento.entrySet()) { | |
| etiquetass.add(e.getValue()); | |
| } | |
| return etiquetass; | |
| } | |
| private TreeMap<Double, Integer> obtenerKDistancias(SortedMap<Double, Integer> distancias) { | |
| int n = 0; | |
| var dist = new TreeMap<Double, Integer>(); | |
| for (Map.Entry<Double, Integer> e : distancias.entrySet()) { | |
| if (n >= k) | |
| break; | |
| dist.put(e.getKey(), e.getValue()); | |
| n++; | |
| } | |
| return dist; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment