Skip to content

Instantly share code, notes, and snippets.

@charecktowa
Created March 25, 2023 18:18
Show Gist options
  • Select an option

  • Save charecktowa/70c67c2e8226651c6905616e07baa2f2 to your computer and use it in GitHub Desktop.

Select an option

Save charecktowa/70c67c2e8226651c6905616e07baa2f2 to your computer and use it in GitHub Desktop.
Código de KMeans en Java
package clustering;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Lloyd's k-means clustering over dense {@code double} vectors.
 *
 * <p>The initial centroids are {@code k} distinct rows of the input, chosen at random. The
 * algorithm then alternates two steps until the centroids stop changing or {@code maxIterations}
 * iterations have run:
 *
 * <ul>
 *   <li>assign every point to its nearest centroid (Euclidean distance; ties go to the lowest
 *       index);
 *   <li>move each centroid to the mean of its assigned points.
 * </ul>
 */
public class KMeans {
    private final int k;
    private final int maxIterations;
    private final Random random;

    /**
     * Creates a clusterer whose initial centroids are drawn from the platform's strong
     * {@link SecureRandom}, or from the default {@link SecureRandom} if no strong algorithm is
     * configured.
     *
     * @param k number of clusters; must be at least 1
     * @param maxIterations upper bound on assignment/update iterations; if not positive,
     *     {@link #predict} returns an empty map
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KMeans(int k, int maxIterations) {
        this(k, maxIterations, defaultRandom());
    }

    /**
     * Creates a clusterer that draws its initial centroids from {@code random}. Pass a seeded
     * generator for reproducible results.
     *
     * @param k number of clusters; must be at least 1
     * @param maxIterations upper bound on assignment/update iterations; if not positive,
     *     {@link #predict} returns an empty map
     * @param random source of randomness for choosing the initial centroids
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws NullPointerException if {@code random} is null
     */
    public KMeans(int k, int maxIterations, Random random) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        this.k = k;
        this.maxIterations = maxIterations;
        this.random = Objects.requireNonNull(random, "random");
    }

    /** Returns the strong SecureRandom, falling back to the default SecureRandom. */
    private static Random defaultRandom() {
        try {
            return SecureRandom.getInstanceStrong();
        } catch (NoSuchAlgorithmException e) {
            // No strong algorithm is configured on this JVM. A default SecureRandom is still
            // more than adequate for picking initial centroids, and it avoids leaving the
            // generator null.
            return new SecureRandom();
        }
    }

    /**
     * Clusters {@code data} and returns the cluster label of every point.
     *
     * @param data points to cluster, one row per point; all rows must have the same length
     * @return a map from a boxed copy of each row to its cluster index in {@code [0, k)}. Keys
     *     are distinct array objects compared by identity, so duplicate rows appear as separate
     *     entries and lookups by content require iterating {@link Map#entrySet()}. The map is
     *     empty when {@code maxIterations <= 0}.
     * @throws NullPointerException if {@code data} or any of its rows is null
     * @throws IllegalArgumentException if {@code data} is empty, has fewer rows than {@code k},
     *     or has rows of different lengths
     */
    public Map<Double[], Integer> predict(double[][] data) {
        int dimension = validarDatos(data);
        double[][] centroides = inicializarCentroides(data);
        int[] asignaciones = null;
        boolean convergio = false;
        for (int iteracion = 0; iteracion < maxIterations && !convergio; iteracion++) {
            asignaciones = asignarClusters(data, centroides);
            double[][] actualizados = actualizarCentroides(data, asignaciones, centroides, dimension);
            // Converged when an update leaves the centroids used by this assignment step unchanged.
            convergio = Arrays.deepEquals(actualizados, centroides);
            centroides = actualizados;
        }
        return construirMapa(data, asignaciones);
    }

    /** Checks that {@code data} is non-empty, rectangular and has at least k rows; returns the row length. */
    private int validarDatos(double[][] data) {
        Objects.requireNonNull(data, "data");
        if (data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one point");
        }
        if (data.length < k) {
            throw new IllegalArgumentException(
                    "k (" + k + ") cannot exceed the number of points (" + data.length + ")");
        }
        int dimension = data[0].length;
        for (int i = 1; i < data.length; i++) {
            if (data[i].length != dimension) {
                throw new IllegalArgumentException(
                        "row " + i + " has length " + data[i].length + ", expected " + dimension);
            }
        }
        return dimension;
    }

    /** Picks k distinct rows of {@code data} at random and returns copies of them as centroids. */
    private double[][] inicializarCentroides(double[][] data) {
        // distinct().limit(k) only terminates because validarDatos guaranteed k <= data.length.
        int[] indices = random.ints(0, data.length).distinct().limit(k).toArray();
        double[][] centroides = new double[k][];
        for (int i = 0; i < k; i++) {
            // Copy so that centroids never alias the caller's rows.
            centroides[i] = data[indices[i]].clone();
        }
        return centroides;
    }

    /** Returns, for every point, the index of its nearest centroid. */
    private int[] asignarClusters(double[][] data, double[][] centroides) {
        int[] asignaciones = new int[data.length];
        double[] distancias = new double[k];
        for (int i = 0; i < data.length; i++) {
            for (int j = 0; j < k; j++) {
                distancias[j] = distanciaEuclidea(data[i], centroides[j]);
            }
            asignaciones[i] = encontrarIndiceMenor(distancias);
        }
        return asignaciones;
    }

    /**
     * Computes the new centroids as the mean of each cluster's points in a single pass over the
     * data. A cluster that received no points keeps its previous centroid.
     */
    private double[][] actualizarCentroides(
            double[][] data, int[] asignaciones, double[][] previos, int dimension) {
        double[][] sumas = new double[k][dimension];
        int[] tamanos = new int[k];
        for (int i = 0; i < data.length; i++) {
            int cluster = asignaciones[i];
            tamanos[cluster]++;
            for (int j = 0; j < dimension; j++) {
                sumas[cluster][j] += data[i][j];
            }
        }
        for (int c = 0; c < k; c++) {
            if (tamanos[c] == 0) {
                // Resetting an empty cluster to the zero vector would drag it to the origin.
                // Keeping it in place lets it pick up points again on a later iteration.
                sumas[c] = previos[c].clone();
            } else {
                for (int j = 0; j < dimension; j++) {
                    sumas[c][j] /= tamanos[c];
                }
            }
        }
        return sumas;
    }

    /** Boxes every row and maps it to its label; returns an empty map if no iteration ran. */
    private static Map<Double[], Integer> construirMapa(double[][] data, int[] asignaciones) {
        Map<Double[], Integer> clusters = new HashMap<>();
        if (asignaciones == null) {
            return clusters;
        }
        for (int i = 0; i < data.length; i++) {
            Double[] punto = Arrays.stream(data[i]).boxed().toArray(Double[]::new);
            clusters.put(punto, asignaciones[i]);
        }
        return clusters;
    }

    /** Euclidean distance between two vectors of the same length. */
    private static double distanciaEuclidea(double[] x, double[] y) {
        double suma = 0.0;
        for (int i = 0; i < x.length; i++) {
            double diferencia = x[i] - y[i];
            suma += diferencia * diferencia;
        }
        return Math.sqrt(suma);
    }

    /** Returns the index of the smallest element; ties resolve to the lowest index. */
    private static int encontrarIndiceMenor(double[] array) {
        int indiceMenor = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] < array[indiceMenor]) {
                indiceMenor = i;
            }
        }
        return indiceMenor;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment