Created
March 25, 2023 18:18
-
-
Save charecktowa/70c67c2e8226651c6905616e07baa2f2 to your computer and use it in GitHub Desktop.
Código de KMeans en Java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package clustering;

import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Lloyd's k-means clustering over dense {@code double[]} points.
 *
 * <p>The initial centroids are {@code k} rows of the input, taken at {@code k} distinct random
 * indices. The algorithm then repeats two steps:
 *
 * <ul>
 *   <li>assign every point to its nearest centroid (Euclidean distance, ties go to the lower
 *       index);
 *   <li>move each centroid to the mean of its assigned points.
 * </ul>
 *
 * <p>It stops when the centroids no longer change or when {@code maxIterations} rounds have run.
 */
public class KMeans {
    private final int k;
    private final int maxIterations;
    private final Random random;

    /**
     * Creates a k-means model that draws its initial centroids from a new {@link SecureRandom}.
     *
     * @param k number of clusters; must be at least 1
     * @param maxIterations upper bound on assignment/update rounds; must be at least 1
     * @throws IllegalArgumentException if {@code k} or {@code maxIterations} is less than 1
     */
    public KMeans(int k, int maxIterations) {
        // The random source is only used to pick the initial centroids.
        // new SecureRandom() cannot throw a checked exception. The previous getInstanceStrong()
        // call could fail, which left the field null and caused an NPE in predict(). It may also
        // select a blocking source on some platforms.
        this(k, maxIterations, new SecureRandom());
    }

    /**
     * Creates a k-means model that uses {@code random} to pick the initial centroids. Passing a
     * seeded {@link Random} makes {@link #predict} reproducible.
     *
     * @param k number of clusters; must be at least 1
     * @param maxIterations upper bound on assignment/update rounds; must be at least 1
     * @param random source of randomness for centroid initialisation; must not be null
     * @throws IllegalArgumentException if {@code k} or {@code maxIterations} is less than 1
     * @throws NullPointerException if {@code random} is null
     */
    public KMeans(int k, int maxIterations, Random random) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations);
        }
        this.k = k;
        this.maxIterations = maxIterations;
        this.random = Objects.requireNonNull(random, "random");
    }

    /**
     * Clusters {@code data} and returns the cluster index assigned to each point.
     *
     * <p>The returned map has one entry per row of {@code data}, iterated in input order. Each
     * key is a freshly boxed copy of its row. Java arrays use identity-based equals and hashCode,
     * so {@code get} only finds entries when called with key instances taken from the map itself.
     * Duplicate rows produce separate entries. Values are cluster indices in {@code [0, k)}.
     *
     * @param data points to cluster, one row per point; all rows must have the same length
     * @return map from a boxed copy of each row to its cluster index, in the order of {@code data}
     * @throws IllegalArgumentException if {@code data} is null or empty, contains a null or
     *     ragged row, or has fewer rows than {@code k}
     */
    public Map<Double[], Integer> predict(double[][] data) {
        validarDatos(data);

        double[][] centroides = inicializarCentroides(data);
        int[] etiquetas = null;
        boolean convergio = false;

        // Each round: assign points with the current centroids, then recompute the centroids.
        // Converged means the recomputed centroids equal the ones just used for the assignment.
        for (int iteracion = 0; iteracion < maxIterations && !convergio; iteracion++) {
            etiquetas = asignarClusters(data, centroides);
            double[][] actualizados = actualizarCentroides(data, etiquetas, centroides);
            convergio = Arrays.deepEquals(actualizados, centroides);
            centroides = actualizados;
        }

        // maxIterations >= 1, so etiquetas has been assigned at least once.
        Map<Double[], Integer> clusters = new LinkedHashMap<>();
        for (int i = 0; i < data.length; i++) {
            Double[] punto = Arrays.stream(data[i]).boxed().toArray(Double[]::new);
            clusters.put(punto, etiquetas[i]);
        }
        return clusters;
    }

    /** Rejects inputs that would crash or hang the algorithm. */
    private void validarDatos(double[][] data) {
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one point");
        }
        for (int i = 0; i < data.length; i++) {
            if (data[i] == null) {
                throw new IllegalArgumentException("row " + i + " is null");
            }
            if (data[i].length != data[0].length) {
                throw new IllegalArgumentException(
                        "row " + i + " has " + data[i].length + " coordinates, expected "
                                + data[0].length);
            }
        }
        // Without this check, drawing k distinct indices from fewer than k rows never terminates.
        if (k > data.length) {
            throw new IllegalArgumentException(
                    "k (" + k + ") exceeds the number of points (" + data.length + ")");
        }
    }

    /** Picks k rows at distinct random indices and copies them as the initial centroids. */
    private double[][] inicializarCentroides(double[][] data) {
        int[] indices = random.ints(0, data.length).distinct().limit(k).toArray();
        double[][] centroides = new double[k][];
        for (int c = 0; c < k; c++) {
            // Copy so centroids never alias the caller's rows.
            centroides[c] = data[indices[c]].clone();
        }
        return centroides;
    }

    /** Returns, for each point, the index of its nearest centroid. */
    private int[] asignarClusters(double[][] data, double[][] centroides) {
        int[] etiquetas = new int[data.length];
        double[] distancias = new double[k];
        for (int i = 0; i < data.length; i++) {
            for (int c = 0; c < k; c++) {
                distancias[c] = distanciaEuclidea(data[i], centroides[c]);
            }
            etiquetas[i] = encontrarIndiceMenor(distancias);
        }
        return etiquetas;
    }

    /**
     * Computes each cluster's mean in a single pass over the data, in index order.
     *
     * <p>The fixed order makes the floating-point sums identical between rounds with identical
     * assignments, so the exact-equality convergence test can fire. An empty cluster keeps its
     * previous centroid instead of collapsing to the origin.
     */
    private double[][] actualizarCentroides(double[][] data, int[] etiquetas, double[][] previos) {
        int dimension = data[0].length;
        double[][] medias = new double[k][dimension];
        int[] tamanos = new int[k];

        for (int i = 0; i < data.length; i++) {
            int c = etiquetas[i];
            tamanos[c]++;
            for (int j = 0; j < dimension; j++) {
                medias[c][j] += data[i][j];
            }
        }

        for (int c = 0; c < k; c++) {
            if (tamanos[c] == 0) {
                medias[c] = previos[c].clone();
            } else {
                for (int j = 0; j < dimension; j++) {
                    medias[c][j] /= tamanos[c];
                }
            }
        }
        return medias;
    }

    /** Euclidean distance between two points of equal dimension. */
    private static double distanciaEuclidea(double[] x, double[] y) {
        double suma = 0.0;
        for (int i = 0; i < x.length; i++) {
            double diferencia = x[i] - y[i];
            suma += diferencia * diferencia;
        }
        return Math.sqrt(suma);
    }

    /** Index of the smallest value; the first occurrence wins on ties. */
    private static int encontrarIndiceMenor(double[] valores) {
        int indiceMenor = 0;
        for (int i = 1; i < valores.length; i++) {
            if (valores[i] < valores[indiceMenor]) {
                indiceMenor = i;
            }
        }
        return indiceMenor;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment