Code for topology filtration of attention using GUDHI.
# === Imports ===
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModel, AutoTokenizer
import gudhi as gd
from sklearn.manifold import MDS
import warnings
warnings.filterwarnings("ignore")

# === Constants and Configurations ===
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
PROMPT_TEXT = "Topology reveals the hidden structure of attention patterns."
MAX_TOKENS = 16
TARGET_LAYERS = [2, 8, 16, 21, 23]
SELECTED_HEAD_POLICY = "middle"  # "first", "last", or "middle"
EPSILONS = [0.3, 0.5, 0.8]
PERSISTENCE_MAX_EDGE = 1.0
RANDOM_SEED = 42

# === Helper Functions ===
def load_qwen_model_and_tokenizer(model_name):
    """Load Qwen model and tokenizer with correct trust and fallback token setup."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_name,
        output_attentions=True,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model

def extract_attention_from_text(model, tokenizer, text, max_tokens):
    """Tokenize input and extract attention tensors from Qwen2.5-0.5B."""
    inputs = tokenizer(text, return_tensors="pt", max_length=max_tokens, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    if outputs.attentions is None:
        raise ValueError("No attention returned by model.")
    # Stack per-layer attentions and drop the batch dim -> (layers, heads, seq, seq).
    return torch.stack([a for a in outputs.attentions if a is not None]).squeeze(1).detach().cpu()

def compute_embedding_from_attention(attn_matrix):
    """Symmetrize attention into a distance matrix and embed with MDS."""
    attn = attn_matrix.numpy()
    dist = 1.0 - (attn + attn.T) / 2.0
    np.fill_diagonal(dist, 0.0)  # self-distance must be zero for a valid dissimilarity matrix
    return MDS(n_components=2, dissimilarity='precomputed', random_state=RANDOM_SEED).fit_transform(dist)

def create_rips_complex(points, epsilon, max_dim=2):
    """Construct Rips complex and return its simplex tree."""
    rips = gd.RipsComplex(points=points, max_edge_length=epsilon)
    return rips.create_simplex_tree(max_dimension=max_dim)

def draw_simplicial_complex(ax, points, simplex_tree, epsilon):
    """Render 2D simplicial complex given simplex tree and embedding."""
    vertices, edges, triangles = [], [], []
    for simplex, _ in simplex_tree.get_simplices():
        if len(simplex) == 1:
            vertices.append(simplex[0])
        elif len(simplex) == 2:
            edges.append(simplex)
        elif len(simplex) == 3:
            triangles.append(simplex)
    for triangle in triangles:
        if all(v < len(points) for v in triangle):
            poly = plt.Polygon(points[triangle], alpha=0.3, facecolor='lightblue', edgecolor='blue', linewidth=1)
            ax.add_patch(poly)
    for edge in edges:
        if all(v < len(points) for v in edge):
            ax.plot(*points[edge].T, 'b-', linewidth=2, alpha=0.8)
    ax.scatter(points[:, 0], points[:, 1], c='darkblue', s=100, zorder=5)
    ax.set_title(f"Simplicial Complex (ε = {epsilon:.1f})", fontsize=12)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(True, alpha=0.3)

def create_topological_visualization(attn_matrix, layer_idx):
    """Create visualization of attention topology and persistent homology."""
    points = compute_embedding_from_attention(attn_matrix)
    fig = plt.figure(figsize=(16, 8))
    fig.suptitle(f"Qwen2.5-0.5B Layer {layer_idx} - Topological Analysis", fontsize=16, weight='bold')
    for i, eps in enumerate(EPSILONS):
        ax = fig.add_subplot(2, 3, i + 1)
        st = create_rips_complex(points, eps)
        draw_simplicial_complex(ax, points, st, eps)
    # Persistence diagram (bottom row, spanning all three columns)
    ax_persist = fig.add_subplot(2, 3, (4, 6))
    st_full = create_rips_complex(points, PERSISTENCE_MAX_EDGE)
    persistence = st_full.persistence()
    gd.plot_persistence_diagram(persistence, axes=ax_persist)
    ax_persist.set_title("Persistence Diagram", fontsize=14)
    plt.tight_layout()
    return fig

# === Main Execution ===
def main():
    tokenizer, model = load_qwen_model_and_tokenizer(MODEL_NAME)
    attention = extract_attention_from_text(model, tokenizer, PROMPT_TEXT, MAX_TOKENS)
    for layer in TARGET_LAYERS:
        if layer >= attention.shape[0]:
            continue
        # Pick the attention head according to SELECTED_HEAD_POLICY (default: middle head).
        if SELECTED_HEAD_POLICY == "first":
            head_idx = 0
        elif SELECTED_HEAD_POLICY == "last":
            head_idx = attention.shape[1] - 1
        else:
            head_idx = attention.shape[1] // 2
        attn_matrix = attention[layer, head_idx]
        fig = create_topological_visualization(attn_matrix, layer)
        fig.savefig(f'qwen_layer_{layer}_topology.png', dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"✓ Saved: Layer {layer} topology visual")
    print("All visualizations complete.")

if __name__ == "__main__":
    main()
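
Beyond the saved figures, the same filtration can be summarized numerically with Betti numbers. The snippet below is a minimal sketch, not part of the gist itself: it assumes it runs in the same session as (or after importing) the definitions above and reuses the helpers and constants; the layer/head choice is just an illustrative default that mirrors main().

# --- Optional numerical summary (sketch; assumes the definitions above are in scope) ---
tokenizer, model = load_qwen_model_and_tokenizer(MODEL_NAME)
attention = extract_attention_from_text(model, tokenizer, PROMPT_TEXT, MAX_TOKENS)
layer, head_idx = TARGET_LAYERS[0], attention.shape[1] // 2  # illustrative choice
points = compute_embedding_from_attention(attention[layer, head_idx])

for eps in EPSILONS:
    st = create_rips_complex(points, eps)
    st.compute_persistence()  # must be called before querying Betti numbers
    print(f"epsilon={eps}: Betti numbers = {st.betti_numbers()}")

compute_persistence() is used here rather than persistence() because only the Betti numbers are needed, not the full list of birth/death intervals.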