REAP

  1. The REAP Formula

  REAP Score = mean(||expert_output|| × router_weight)

  - ||expert_output||: L2 norm of the expert's output
  - router_weight: the expert's weight from softmax(router_logits)
  - Computed only over the tokens actually routed to that expert (a toy numeric example is sketched below)
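
  To make the formula concrete, here is a minimal sketch with made-up toy values (not taken from the scripts below): three tokens routed to a single expert, their output norms, and that expert's router weights.

  import torch

  # Toy values: three tokens routed to one expert (hypothetical numbers)
  expert_output = torch.tensor([[3.0, 4.0],    # ||.|| = 5.0
                                [0.0, 2.0],    # ||.|| = 2.0
                                [6.0, 8.0]])   # ||.|| = 10.0
  router_weight = torch.tensor([0.5, 0.1, 0.2])  # softmax weights for this expert

  norms = torch.linalg.norm(expert_output, dim=-1)   # [5.0, 2.0, 10.0]
  reap_score = (norms * router_weight).mean()        # (2.5 + 0.2 + 2.0) / 3 ≈ 1.57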

  ---
  2. Core Code Structure

  The REAPObserver class registers forward hooks that collect per-expert statistics during the forward pass (a usage sketch follows the snippet):

  class REAPObserver:
      def __init__(self, model, model_attrs):
          self.state = {}  # layer_idx -> {reap_sum, expert_count}
          self.hooks = []
          self._register_hooks()
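
  For orientation, the observer is driven roughly like this; a condensed sketch of the calibration loop from the full script further below, assuming model, tokenizer, model_attrs, and calib_texts are set up as in that script:

  observer = REAPObserver(model, model_attrs)
  with torch.no_grad():
      for text in calib_texts:
          inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
          model(**inputs)                     # hooks fire during the forward pass
  reap_scores = observer.get_reap_scores()    # layer_idx -> per-expert mean REAP scores
  observer.close()                            # remove the hooks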

  ---
  3. The Hook Function (core logic)

  @torch.no_grad()
  def hook_fn(module, args, output):
      # 1. Get the input hidden states
      hidden_states = args[0]  # (batch, seq, hidden_dim)
      flat_hidden = hidden_states.view(-1, hidden_dim)  # (total_tokens, hidden_dim)

      # 2. Compute router weights (softmax)
      routing_weights = F.softmax(router_logits.float(), dim=-1)

      # 3. Find the top-k selected experts
      _, selected_experts = torch.topk(router_logits, top_k, dim=-1)

      # 4. Compute the REAP score for each expert
      for expert_idx in range(num_experts):
          # Find the tokens routed to this expert
          active_mask = (selected_experts == expert_idx).any(dim=-1)

          if not active_mask.any():
              continue

          # Router weights for those tokens
          active_router_weights = routing_weights[active_mask, expert_idx]

          # Compute the expert output (the key step!)
          active_hidden = flat_hidden[active_mask]
          expert_output = expert(active_hidden)  # actual expert forward pass

          # L2 norm of the output
          activation_norms = torch.linalg.norm(expert_output.float(), dim=-1)

          # REAP = norm × weight
          reap_scores = activation_norms * active_router_weights

          # Accumulate
          self.state[layer_idx]["reap_sum"][expert_idx] += reap_scores.sum()
          self.state[layer_idx]["expert_count"][expert_idx] += active_mask.sum()

  ---
  4. Final REAP Score Computation

  def get_reap_scores(self):
      reap_scores = {}
      for layer_idx, layer_state in self.state.items():
          counts = layer_state["expert_count"].float()
          counts = torch.clamp(counts, min=1)  # avoid division by zero

          # Mean REAP score per expert
          reap_scores[layer_idx] = layer_state["reap_sum"] / counts
      return reap_scores

  ---
  5. Pruning (remove the lowest-scoring experts)

  # Select the experts with the lowest REAP scores
  _, experts_to_prune = torch.topk(scores, n_to_prune, largest=False)

  # Indices of the experts to keep
  retained_indices = [i for i in range(num_experts) if i not in experts_to_prune]

  # Prune the expert modules
  new_experts = nn.ModuleList([experts[i] for i in retained_indices])
  setattr(moe_block, "experts", new_experts)

  # Prune the router weights as well
  new_weight = router.weight.data[retained_indices, :]
  new_router = nn.Linear(in_features, len(retained_indices), ...)
  new_router.weight.data = new_weight

  ---
  6. Why REAP?

  | Method          | Criterion     | Weakness                                                 |
  |-----------------|---------------|----------------------------------------------------------|
  | Frequency       | Call count    | Frequently selected experts may still contribute little  |
  | Activation Norm | Output size   | The router may assign them low weight                    |
  | REAP            | Norm × Weight | Reflects actual contribution                             |

  REAP jointly accounts for how often an expert is used, how large its outputs are, and how much weight the router gives it.
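
  A minimal sketch with hypothetical per-expert statistics showing how a frequency-only ranking and the REAP ranking can disagree: expert 0 is selected most often but contributes the least per token, so REAP prunes it first.

  import torch

  # Hypothetical accumulated statistics for three experts
  call_counts = torch.tensor([900.,  50., 400.])   # how often each expert was selected
  reap_sums   = torch.tensor([ 90., 100., 600.])   # sum of ||expert_output|| * router_weight

  freq_order = torch.argsort(call_counts)                           # prune order: [1, 2, 0]
  reap_order = torch.argsort(reap_sums / call_counts.clamp(min=1))  # prune order: [0, 2, 1]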
import torch
import time
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
# ==========================================
# Configuration
# ==========================================
ORIGINAL_MODEL = "kakaocorp/kanana-1.5-15.7b-a3b-base"
PRUNED_MODEL = "./kanana-reap-10" # 10% REAP pruning (official algorithm)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Test prompts for speed benchmark (Korean + English mix)
TEST_PROMPTS = [
    "인공지능의 미래에 대해 설명해주세요.",
    "한국의 역사에서 가장 중요한 사건은 무엇인가요?",
    "The theory of relativity states that",
    "Machine learning is a subset of artificial intelligence that",
    "서울에서 부산까지 가는 가장 빠른 방법은",
]
def get_gpu_memory():
    """Get current GPU memory usage in GB"""
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1e9
    return 0

def clear_memory():
    """Clear GPU memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

def load_model(model_path):
    """Load model with memory tracking"""
    clear_memory()
    print(f"\n📦 Loading: {model_path}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    memory_used = get_gpu_memory()
    print(f" ✓ Loaded in {load_time:.1f}s | GPU Memory: {memory_used:.2f} GB")
    return model, tokenizer, memory_used, load_time

def benchmark_speed(model, tokenizer, prompts, max_new_tokens=100, num_runs=3):
    """Benchmark generation speed"""
    print(f"\n⚡ Speed Benchmark (max_new_tokens={max_new_tokens}, runs={num_runs})")
    model.eval()
    total_tokens = 0
    total_time = 0
    with torch.no_grad():
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            # Warmup
            _ = model.generate(**inputs, max_new_tokens=10, do_sample=False)
            # Actual benchmark
            for _ in range(num_runs):
                torch.cuda.synchronize() if torch.cuda.is_available() else None
                start = time.time()
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )
                torch.cuda.synchronize() if torch.cuda.is_available() else None
                elapsed = time.time() - start
                generated_tokens = outputs.shape[1] - inputs['input_ids'].shape[1]
                total_tokens += generated_tokens
                total_time += elapsed
    tokens_per_sec = total_tokens / total_time
    print(f" ✓ {tokens_per_sec:.2f} tokens/sec ({total_tokens} tokens in {total_time:.2f}s)")
    return tokens_per_sec

def calculate_perplexity(model, tokenizer, texts, max_length=512):
    """Calculate perplexity on given texts"""
    print(f"\n📊 Calculating Perplexity...")
    model.eval()
    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model(**inputs, labels=inputs["input_ids"])
            # Loss is averaged over tokens
            num_tokens = inputs["input_ids"].shape[1]
            total_loss += outputs.loss.item() * num_tokens
            total_tokens += num_tokens
    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    print(f" ✓ Perplexity: {perplexity:.2f}")
    return perplexity

def count_parameters(model):
    """Count total and MoE parameters"""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable
def main():
    print("=" * 60)
    print("🔬 REAP Pruning Evaluation")
    print("=" * 60)

    # Load evaluation data
    print("\n📂 Loading evaluation data...")
    try:
        # Try Korean dataset first
        eval_ds = load_dataset("maywell/ko_wikidata_QA", split="train", streaming=True)
        eval_texts = []
        for i, item in enumerate(eval_ds):
            if i >= 50:  # Use 50 samples for perplexity
                break
            text = f"질문: {item.get('instruction', '')}\n답변: {item.get('output', '')}"
            if len(text) > 50:
                eval_texts.append(text)
        print(f" ✓ Loaded {len(eval_texts)} Korean QA samples")
    except Exception as e:
        print(f" ⚠ Failed to load Korean data: {e}")
        eval_texts = TEST_PROMPTS * 10

    results = {}

    # ==========================================
    # Evaluate Original Model
    # ==========================================
    print("\n" + "=" * 60)
    print("📍 ORIGINAL MODEL")
    print("=" * 60)
    model, tokenizer, orig_memory, orig_load_time = load_model(ORIGINAL_MODEL)
    orig_params, _ = count_parameters(model)
    print(f" Parameters: {orig_params / 1e9:.2f}B")
    orig_speed = benchmark_speed(model, tokenizer, TEST_PROMPTS)
    orig_ppl = calculate_perplexity(model, tokenizer, eval_texts[:30])
    results['original'] = {
        'memory_gb': orig_memory,
        'load_time': orig_load_time,
        'params_b': orig_params / 1e9,
        'tokens_per_sec': orig_speed,
        'perplexity': orig_ppl
    }

    # Clear memory before loading next model
    del model, tokenizer
    clear_memory()

    # ==========================================
    # Evaluate Pruned Model
    # ==========================================
    print("\n" + "=" * 60)
    print("📍 PRUNED MODEL (10% experts removed - official REAP)")
    print("=" * 60)
    model, tokenizer, pruned_memory, pruned_load_time = load_model(PRUNED_MODEL)
    pruned_params, _ = count_parameters(model)
    print(f" Parameters: {pruned_params / 1e9:.2f}B")
    pruned_speed = benchmark_speed(model, tokenizer, TEST_PROMPTS)
    pruned_ppl = calculate_perplexity(model, tokenizer, eval_texts[:30])
    results['pruned'] = {
        'memory_gb': pruned_memory,
        'load_time': pruned_load_time,
        'params_b': pruned_params / 1e9,
        'tokens_per_sec': pruned_speed,
        'perplexity': pruned_ppl
    }

    # ==========================================
    # Summary
    # ==========================================
    print("\n" + "=" * 60)
    print("📈 COMPARISON SUMMARY")
    print("=" * 60)
    o, p = results['original'], results['pruned']
    param_reduction = (1 - p['params_b'] / o['params_b']) * 100
    memory_reduction = (1 - p['memory_gb'] / o['memory_gb']) * 100
    speed_improvement = (p['tokens_per_sec'] / o['tokens_per_sec'] - 1) * 100
    ppl_change = (p['perplexity'] / o['perplexity'] - 1) * 100
print(f"""
┌─────────────────────┬─────────────┬─────────────┬─────────────┐
│ Metric │ Original │ Pruned │ Change │
├─────────────────────┼─────────────┼─────────────┼─────────────┤
│ Parameters (B) │ {o['params_b']:>10.2f} │ {p['params_b']:>10.2f} │ {param_reduction:>+9.1f}% │
│ GPU Memory (GB) │ {o['memory_gb']:>10.2f} │ {p['memory_gb']:>10.2f} │ {memory_reduction:>+9.1f}% │
│ Speed (tok/s) │ {o['tokens_per_sec']:>10.2f} │ {p['tokens_per_sec']:>10.2f} │ {speed_improvement:>+9.1f}% │
│ Perplexity │ {o['perplexity']:>10.2f} │ {p['perplexity']:>10.2f} │ {ppl_change:>+9.1f}% │
│ Load Time (s) │ {o['load_time']:>10.1f} │ {p['load_time']:>10.1f} │ │
└─────────────────────┴─────────────┴─────────────┴─────────────┘
💡 Interpretation:
- Parameter reduction: {param_reduction:.1f}% smaller model
- Memory savings: {memory_reduction:.1f}% less GPU RAM needed
- Speed: {"faster" if speed_improvement > 0 else "slower"} by {abs(speed_improvement):.1f}%
- Quality: Perplexity {"increased" if ppl_change > 0 else "decreased"} by {abs(ppl_change):.1f}% (lower is better)
""")
    # Quality assessment
    if ppl_change < 5:
        print("✅ Quality preserved well! (<5% perplexity increase)")
    elif ppl_change < 10:
        print("⚠️ Slight quality degradation (5-10% perplexity increase)")
    else:
        print("❌ Significant quality loss (>10% perplexity increase)")

if __name__ == "__main__":
    main()
"""
REAP (Router-weighted Expert Activation Pruning) - Official Algorithm
Based on: https://github.com/CerebrasResearch/reap
REAP Score = mean(activation_norm * router_weight) for each expert
- activation_norm: L2 norm of expert OUTPUT (not input)
- router_weight: softmax(router_logits)
- Only considers tokens actually routed to each expert (top-k selection)
Experts with LOWEST REAP scores are pruned (they contribute least to output)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import random
import gc
from dataclasses import dataclass
from typing import Optional
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ==========================================
# Configuration
# ==========================================
@dataclass
class REAPConfig:
    model_id: str = "kakaocorp/kanana-1.5-15.7b-a3b-base"
    save_path: str = "./kanana-reap-10"
    compression_ratio: float = 0.10  # 10% of experts to prune
    num_samples: int = 256  # Number of calibration samples
    max_seq_len: int = 1024
    seed: int = 42
config = REAPConfig()
random.seed(config.seed)
torch.manual_seed(config.seed)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"REAP Pruning Configuration:")
print(f" Model: {config.model_id}")
print(f" Compression Ratio: {config.compression_ratio*100}%")
print(f" Calibration Samples: {config.num_samples}")
# ==========================================
# Model Attributes for Different Architectures
# ==========================================
MODEL_ATTRS = {
    "MixtralForCausalLM": {
        "moe_block": "block_sparse_moe",
        "experts": "experts",
        "router": "gate",
        "num_experts": "num_local_experts",
        "top_k": "num_experts_per_tok",
    },
    "Qwen2MoeForCausalLM": {
        "moe_block": "mlp",
        "experts": "experts",
        "router": "gate",
        "num_experts": "num_experts",
        "top_k": "num_experts_per_tok",
    },
}
# ==========================================
# Load Model and Tokenizer
# ==========================================
print("\nLoading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    config.model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
# Detect model architecture
model_class_name = model.__class__.__name__
print(f"Model class: {model_class_name}")
# Get model attributes (fallback to Mixtral-like structure)
if model_class_name in MODEL_ATTRS:
    model_attrs = MODEL_ATTRS[model_class_name]
else:
    # Default to Mixtral structure for Kanana
    model_attrs = MODEL_ATTRS["MixtralForCausalLM"]
    logger.warning(f"Using MixtralForCausalLM attributes for {model_class_name}")
# ==========================================
# Load Calibration Dataset
# ==========================================
print("\nLoading calibration dataset (evol-codealpaca-v1)...")
dataset = load_dataset("theblackcat102/evol-codealpaca-v1", split="train", streaming=True)
calib_texts = []
iterator = iter(dataset)
pbar = tqdm(total=config.num_samples, desc="Loading samples")
while len(calib_texts) < config.num_samples:
    try:
        item = next(iterator)
        # Format as instruction-response pair
        text = f"### Instruction:\n{item['instruction']}\n\n### Response:\n{item['output']}"
        if len(text) > 100:
            calib_texts.append(text)
            pbar.update(1)
    except StopIteration:
        break
pbar.close()
print(f"Loaded {len(calib_texts)} calibration samples")
# ==========================================
# REAP Observer - Collect Expert Statistics
# ==========================================
class REAPObserver:
    """
    Collects REAP statistics for each MoE layer.
    REAP = mean(||expert_output|| * softmax(router_logits))
    for tokens actually routed to each expert.
    """
    def __init__(self, model, model_attrs):
        self.model = model
        self.model_attrs = model_attrs
        self.state = {}  # layer_idx -> {reap_sum, expert_count, ...}
        self.hooks = []
        self._register_hooks()

    def _get_moe_block(self, layer):
        """Get MoE block from layer."""
        return getattr(layer, self.model_attrs["moe_block"])

    def _register_hooks(self):
        """Register forward hooks on MoE blocks."""
        for layer_idx, layer in enumerate(self.model.model.layers):
            moe_block = self._get_moe_block(layer)
            if moe_block is None:
                continue
            # Check if this is actually an MoE layer
            if not hasattr(moe_block, self.model_attrs["experts"]):
                continue
            experts = getattr(moe_block, self.model_attrs["experts"])
            router = getattr(moe_block, self.model_attrs["router"])
            num_experts = len(experts)
            top_k = getattr(self.model.config, self.model_attrs["top_k"], 8)
            # Initialize state for this layer
            self.state[layer_idx] = {
                "reap_sum": torch.zeros(num_experts, dtype=torch.float64),
                "expert_count": torch.zeros(num_experts, dtype=torch.long),
                "num_experts": num_experts,
                "top_k": top_k,
            }
            # Create and register hook
            hook_fn = self._create_hook(layer_idx, moe_block, experts, router, num_experts, top_k)
            handle = moe_block.register_forward_hook(hook_fn)
            self.hooks.append(handle)
        logger.info(f"Registered hooks on {len(self.state)} MoE layers")

    def _create_hook(self, layer_idx, moe_block, experts, router, num_experts, top_k):
        """Create a hook function for a specific layer."""
        @torch.no_grad()
        def hook_fn(module, args, output):
            # Get input hidden states
            hidden_states = args[0]  # (batch, seq, hidden_dim)
            # Handle different output formats
            if isinstance(output, tuple):
                # Mixtral returns (output, router_logits)
                if len(output) >= 2:
                    router_logits = output[1]
                else:
                    router_logits = output[0]
            else:
                router_logits = output
            # Flatten batch and sequence dimensions
            batch_size, seq_len, hidden_dim = hidden_states.shape
            flat_hidden = hidden_states.view(-1, hidden_dim)  # (total_tokens, hidden_dim)
            total_tokens = flat_hidden.shape[0]
            # Get router logits if not already computed
            if router_logits.shape[0] != total_tokens:
                router_logits = router(flat_hidden)  # (total_tokens, num_experts)
            # Compute routing weights (softmax)
            routing_weights = F.softmax(router_logits.float(), dim=-1)  # (total_tokens, num_experts)
            # Get top-k selected experts
            _, selected_experts = torch.topk(router_logits, top_k, dim=-1)  # (total_tokens, top_k)
            # Compute expert activations and REAP scores
            device = flat_hidden.device
            for expert_idx in range(num_experts):
                # Find tokens routed to this expert
                active_mask = (selected_experts == expert_idx).any(dim=-1)  # (total_tokens,)
                if not active_mask.any():
                    continue
                # Get router weights for active tokens
                active_router_weights = routing_weights[active_mask, expert_idx]  # (n_active,)
                # Compute expert output for active tokens
                active_hidden = flat_hidden[active_mask]  # (n_active, hidden_dim)
                expert = experts[expert_idx]
                expert_output = expert(active_hidden)  # (n_active, hidden_dim)
                # Compute activation norms (L2)
                activation_norms = torch.linalg.norm(expert_output.float(), dim=-1)  # (n_active,)
                # REAP score = activation_norm * router_weight
                reap_scores = activation_norms * active_router_weights  # (n_active,)
                # Accumulate
                self.state[layer_idx]["reap_sum"][expert_idx] += reap_scores.sum().cpu()
                self.state[layer_idx]["expert_count"][expert_idx] += active_mask.sum().cpu()
        return hook_fn

    def get_reap_scores(self):
        """Compute final REAP scores (mean) for each layer."""
        reap_scores = {}
        for layer_idx, layer_state in self.state.items():
            counts = layer_state["expert_count"].float()
            counts = torch.clamp(counts, min=1)  # Avoid division by zero
            reap_scores[layer_idx] = (layer_state["reap_sum"] / counts).float()
        return reap_scores

    def close(self):
        """Remove all hooks."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
# ==========================================
# Run Calibration
# ==========================================
print("\nRunning REAP calibration...")
observer = REAPObserver(model, model_attrs)
with torch.no_grad():
    for text in tqdm(calib_texts, desc="Calibration"):
        try:
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=config.max_seq_len
            ).to(model.device)
            model(**inputs)
        except RuntimeError as e:
            if "out of memory" in str(e):
                torch.cuda.empty_cache()
                continue
            raise
# Get REAP scores
reap_scores = observer.get_reap_scores()
observer.close()
torch.cuda.empty_cache()
print("\nREAP scores computed for each layer:")
for layer_idx, scores in reap_scores.items():
print(f" Layer {layer_idx}: min={scores.min():.4f}, max={scores.max():.4f}, mean={scores.mean():.4f}")
# ==========================================
# Pruning
# ==========================================
print(f"\nPruning {config.compression_ratio*100}% of experts per layer...")
total_pruned = 0
for layer_idx, scores in tqdm(reap_scores.items(), desc="Pruning layers"):
    layer = model.model.layers[layer_idx]
    moe_block = getattr(layer, model_attrs["moe_block"])
    experts = getattr(moe_block, model_attrs["experts"])
    router = getattr(moe_block, model_attrs["router"])
    num_experts = len(experts)
    n_to_prune = int(num_experts * config.compression_ratio)
    n_to_keep = num_experts - n_to_prune
    if n_to_keep < 1:
        n_to_keep = 1
        n_to_prune = num_experts - 1
    # Select experts with LOWEST REAP scores to prune
    _, experts_to_prune = torch.topk(scores, n_to_prune, largest=False)
    retained_indices = [i for i in range(num_experts) if i not in experts_to_prune.tolist()]
    retained_indices = sorted(retained_indices)
    print(f" Layer {layer_idx}: {num_experts} -> {n_to_keep} experts (pruning indices: {experts_to_prune.tolist()})")
    # Prune experts
    new_experts = nn.ModuleList([experts[i] for i in retained_indices])
    setattr(moe_block, model_attrs["experts"], new_experts)
    # Prune router weights
    retained_indices_tensor = torch.tensor(retained_indices, device=router.weight.device)
    new_weight = router.weight.data[retained_indices_tensor, :]
    new_router = nn.Linear(
        router.in_features,
        n_to_keep,
        bias=router.bias is not None,
        device=router.weight.device,
        dtype=router.weight.dtype
    )
    new_router.weight.data = new_weight
    if router.bias is not None:
        new_router.bias.data = router.bias.data[retained_indices_tensor]
    setattr(moe_block, model_attrs["router"], new_router)
    # Update metadata
    if hasattr(moe_block, "num_experts"):
        moe_block.num_experts = n_to_keep
    if hasattr(moe_block, "num_local_experts"):
        moe_block.num_local_experts = n_to_keep
    total_pruned += n_to_prune
# Update model config
# Update model config to match the per-layer expert count after pruning
if hasattr(model.config, "num_local_experts"):
    original_experts = model.config.num_local_experts
    model.config.num_local_experts = original_experts - int(original_experts * config.compression_ratio)
if hasattr(model.config, "num_experts"):
    original_experts = model.config.num_experts
    model.config.num_experts = original_experts - int(original_experts * config.compression_ratio)
# ==========================================
# Save Pruned Model
# ==========================================
print(f"\nSaving pruned model to {config.save_path}...")
model.save_pretrained(config.save_path)
tokenizer.save_pretrained(config.save_path)
print(f"""
REAP Pruning Complete!
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Model: {config.model_id}
Compression Ratio: {config.compression_ratio*100}%
Total Experts Pruned: {total_pruned}
Calibration Samples: {len(calib_texts)}
Saved to: {config.save_path}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
""")