Skip to content

Instantly share code, notes, and snippets.

@AdamSpannbauer
Last active February 26, 2026 02:28
Show Gist options
  • Select an option

  • Save AdamSpannbauer/72af0e090c344b21e9de307fad5380ad to your computer and use it in GitHub Desktop.

Select an option

Save AdamSpannbauer/72af0e090c344b21e9de307fad5380ad to your computer and use it in GitHub Desktop.
Helper to make a conditionally formatted table of kmeans centroids with ggplot2
library(ggplot2)
library(dplyr)
library(tidyr)
#' Plot k-means centroids as a colour-coded table
#'
#' Draws one tile per (cluster, feature) pair. Each tile is filled according
#' to the centroid value and labelled with that value rounded to 2 decimals.
#'
#' @param kmeans_object An object with a `centers` element holding one row
#'   per cluster and one column per feature (e.g. the result of
#'   `stats::kmeans()`).
#' @return A ggplot object with clusters on the x axis and features on the
#'   y axis.
plot_centroids_table <- function(kmeans_object) {
  centers <- kmeans_object$centers
  if (is.null(dim(centers)) || nrow(centers) < 1L) {
    stop(
      "`kmeans_object$centers` must be a matrix with at least one row",
      call. = FALSE
    )
  }

  # seq_len() instead of 1:n so the name count always matches the row count.
  cluster_names <- paste("Cluster", seq_len(nrow(centers)))

  # Transpose so that features become rows and clusters become columns.
  plot_df <- data.frame(t(centers))
  names(plot_df) <- cluster_names
  plot_df$feature_name <- rownames(plot_df)
  plot_df <- tidyr::pivot_longer(plot_df, cols = -feature_name)

  # A plain character axis is ordered alphabetically, which would put
  # "Cluster 10" before "Cluster 2"; explicit levels keep numeric order.
  plot_df$name <- factor(plot_df$name, levels = cluster_names)

  ggplot2::ggplot(
    plot_df,
    ggplot2::aes(x = name, y = feature_name, fill = value)
  ) +
    ggplot2::geom_tile() +
    ggplot2::geom_text(
      ggplot2::aes(label = round(value, 2)),
      color = "white"
    ) +
    ggplot2::labs(x = "", y = "")
}
#' Plot a per-group summary of every column as a colour-coded table
#'
#' Groups `your_data` by `label_column`, applies `summary_func` to every
#' remaining column, and draws one tile per (group, column) pair labelled
#' with the rounded summary value.
#'
#' @param your_data A data frame containing the label column and the columns
#'   to summarise.
#' @param label_column Name of the column in `your_data` holding the group
#'   labels. Groups are displayed as `"Cluster <label>"`.
#' @param summary_func Function applied to each column within each group.
#' @param drop_cols Columns to remove from `your_data` before summarising.
#' @param color_within How tile colours are scaled: `"table"` uses the raw
#'   summary values, `"row"` rescales to [0, 1] within each summarised
#'   column (a row of the plot), `"col"` rescales to [0, 1] within each
#'   group (a column of the plot).
#' @param round_digits Number of digits used when rounding the tile labels.
#' @return A ggplot object. For `"row"` / `"col"` the legend is hidden and
#'   white separator lines are drawn between rows / columns.
plot_mean_by_label_table <- function(your_data, label_column,
                                     summary_func = mean,
                                     drop_cols = c(),
                                     color_within = c("table", "row", "col"),
                                     round_digits = 3) {
  color_within <- match.arg(color_within)

  # Only touch the data when there is actually something to drop.
  if (length(drop_cols) > 0) {
    your_data[drop_cols] <- NULL
  }

  labels <- your_data[[label_column]]
  if (is.null(labels)) {
    # Without this check paste("Cluster", NULL) would be recycled and every
    # row would silently land in a single "Cluster " group.
    stop(
      "Column '", label_column, "' was not found in `your_data`",
      call. = FALSE
    )
  }

  # Build the levels from the sorted labels so that e.g. 10 comes after 2
  # on the x axis instead of being ordered alphabetically.
  group_levels <- unique(
    paste("Cluster", sort(unique(labels), na.last = TRUE))
  )
  your_data$group_name <- factor(
    paste("Cluster", labels),
    levels = group_levels
  )
  your_data[[label_column]] <- NULL

  plot_df <- your_data |>
    dplyr::group_by(group_name) |>
    dplyr::summarise(
      dplyr::across(dplyr::everything(), summary_func),
      .groups = "drop"
    ) |>
    tidyr::pivot_longer(-group_name)

  # Rescale a vector to [0, 1]; a constant vector maps to 0.5 and an
  # all-NA vector stays NA (avoids the range() warning).
  scale01 <- function(x) {
    if (all(is.na(x))) {
      return(rep(NA_real_, length(x)))
    }
    rng <- range(x, na.rm = TRUE)
    if (diff(rng) == 0) {
      return(rep(0.5, length(x)))
    }
    (x - rng[1]) / diff(rng)
  }

  # `color_within` is a single string, so dispatch with switch(); only the
  # selected branch is evaluated.
  plot_df$fill_value <- switch(color_within,
    table = plot_df$value,
    col = ave(plot_df$value, plot_df$group_name, FUN = scale01),
    row = ave(plot_df$value, plot_df$name, FUN = scale01)
  )

  p <- ggplot2::ggplot(
    plot_df,
    ggplot2::aes(x = group_name, y = name, fill = fill_value)
  ) +
    ggplot2::geom_tile() +
    ggplot2::geom_text(
      ggplot2::aes(label = round(value, round_digits)),
      color = "white"
    ) +
    ggplot2::labs(x = "", y = "", fill = "") +
    ggplot2::scale_fill_gradient(low = "#132B43", high = "#56B1F7") +
    ggplot2::theme_minimal()

  # Colours rescaled per row/column are not comparable across the table,
  # so the shared legend is hidden in those modes.
  if (color_within %in% c("col", "row")) {
    p <- p + ggplot2::theme(legend.position = "none")
  }

  # White separators mark the unit within which colours were scaled.
  if (color_within == "row") {
    y_breaks <- seq_along(unique(plot_df$name)) + 0.5
    p <- p +
      ggplot2::geom_hline(
        yintercept = y_breaks,
        color = "white",
        linewidth = 0.5
      )
  }

  if (color_within == "col") {
    x_breaks <- seq_along(unique(plot_df$group_name)) + 0.5
    p <- p +
      ggplot2::geom_vline(
        xintercept = x_breaks,
        color = "white",
        linewidth = 0.5
      )
  }

  p
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment