Skip to content

Instantly share code, notes, and snippets.

@kkoomen
Created January 17, 2026 14:44
Show Gist options
  • Select an option

  • Save kkoomen/722fee9aa070eb6cc0e653633520c58d to your computer and use it in GitHub Desktop.

Select an option

Save kkoomen/722fee9aa070eb6cc0e653633520c58d to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from sae import SAE
from utils import SAEDataset
from sae_naming import compute_similarities
# Resolve the sparse-autoencoder checkpoint through hf_hub_download and hand
# the resulting path to the project's SAE wrapper.
MODEL_PATH = hf_hub_download(
    repo_id="WolodjaZ/MSAE",
    filename="ViT-L_14/not_centered/24576_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth",
)
model = SAE(MODEL_PATH)

# Vocabulary embeddings are read from a local .npy file, with mean-centering
# switched off and target_norm set to 0.0.
# NOTE(review): the exact meaning of target_norm=0.0 is defined by SAEDataset
# (not visible here) — presumably "do not rescale"; confirm.
VOCAB_EMB_PATH = "data/laion_unigram_ViT-L~14_-1_text_37445_768.npy"
vocab_dataset = SAEDataset(
    VOCAB_EMB_PATH,
    mean_center=False,
    target_norm=0.0,
)

# Score the vocabulary dataset against the model. The statistics further down
# index the result as a 2-D array, taking the best vocabulary entry along
# axis 0 and counting neurons along axis 1.
similarity_matrix = compute_similarities(
    model, vocab_dataset, patch_diff=True, batch_size=1024, num_workers=9
)
# Best and second-best vocabulary score per neuron.
# NOTE(review): everything below treats similarity_matrix as a 2-D NumPy array
# whose columns are neurons (axis 0 = vocabulary entries, axis 1 = neurons) —
# confirm against compute_similarities.
best_vocab = np.argmax(similarity_matrix, axis=0)
best_score = np.max(similarity_matrix, axis=0)
masked = similarity_matrix.copy()
# Knock out each neuron's top entry so the column maximum of `masked` is the
# runner-up score.
masked[best_vocab, np.arange(similarity_matrix.shape[1])] = -np.inf  # ignore top
second_best = np.max(masked, axis=0)  # second-highest score per neuron

SIMILARITY_THRESHOLD = 0.25
RATIO_THRESHOLD = 2
_, N = similarity_matrix.shape

# The three per-neuron conditions are computed as boolean arrays of length N
# instead of inside a Python loop; the resulting counts are the same.

# Condition 1: the neuron's best similarity exceeds the threshold.
passes_sim = best_score > SIMILARITY_THRESHOLD

# Condition 2: the neuron is itself the arg-max neuron of its best vocabulary
# entry. A single row-wise argmax replaces one argmax call per neuron.
best_neuron_per_vocab = np.argmax(similarity_matrix, axis=1)
is_best = best_neuron_per_vocab[best_vocab] == np.arange(N)

# Condition 3: best / second-best exceeds the ratio threshold. Where the
# runner-up is not strictly positive the ratio is defined as +inf (and thus
# passes), so `ratio` starts at inf and is only overwritten where the
# division is meaningful.
ratio = np.full(N, np.inf)
np.divide(best_score, second_best, out=ratio, where=second_best > 0)
passes_ratio = ratio > RATIO_THRESHOLD

# Counts (converted to plain ints so downstream formatting is unchanged).
c_sim = int(np.count_nonzero(passes_sim))
c_best = int(np.count_nonzero(is_best))
c_sim_best = int(np.count_nonzero(passes_sim & is_best))
c_ratio = int(np.count_nonzero(passes_ratio))
c_all = int(np.count_nonzero(passes_sim & is_best & passes_ratio))
# Print the concept-validation summary: model info first, then one row per
# counter computed above.
expansion_factor = model.latent_dim // model.input_dim

heavy_rule = "=" * 50
light_rule = "-" * 50
report_lines = [
    "\n" + heavy_rule,  # leading blank line separates the table from earlier output
    "CONCEPT VALIDATION (Table 3)",
    heavy_rule,
    "Model: ViT-L/14",
    f"Expansion: {expansion_factor}×",
    f"Neurons: {N}",
    light_rule,
    f"Similarity > {SIMILARITY_THRESHOLD:<5} {c_sim:>6}",
    f"Best vector: {c_best:>6}",
    f"Above and best: {c_sim_best:>6}",
    f"Ratio > {RATIO_THRESHOLD:<4} {c_ratio:>6}",
    f"All conditions: {c_all:>6}",
    heavy_rule,
]
# One print call per row, matching the original output line for line.
for report_line in report_lines:
    print(report_line)
# Example output of the above:
#
# ==================================================
# Model: ViT-L/14
# Expansion: 32×
# Neurons: 24576
# --------------------------------------------------
# Similarity > 0.25 521
# Best vector: 1671
# Above and best: 505
# Ratio > 2 19
# All conditions: 3
# ==================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment