1. The REAP formula
REAP Score = mean(||expert_output|| × router_weight)
- ||expert_output||: L2 norm of the expert's output
- router_weight: the expert's weight under softmax(router_logits)
- Computed only over the tokens actually routed to that expert (a toy numeric sketch follows this list)
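A minimal numeric sketch of the score for one expert, using made-up tensors; the shapes, seed, and values are illustrative only and do not come from the scripts below:

```python
import torch

# Toy setup: 6 tokens, 4 experts, top-2 routing (illustrative values only).
torch.manual_seed(0)
router_logits = torch.randn(6, 4)                       # (tokens, experts)
routing_weights = torch.softmax(router_logits, dim=-1)  # (tokens, experts)
_, selected = torch.topk(router_logits, k=2, dim=-1)    # (tokens, 2)

expert_idx = 0
mask = (selected == expert_idx).any(dim=-1)             # tokens routed to expert 0
if mask.any():
    # Stand-in for the real expert forward: random outputs for the routed tokens.
    expert_out = torch.randn(int(mask.sum()), 8)        # (n_active, hidden_dim)
    norms = torch.linalg.norm(expert_out, dim=-1)       # ||expert_output||
    weights = routing_weights[mask, expert_idx]         # router_weight
    reap = (norms * weights).mean()                     # mean over routed tokens
    print(f"expert {expert_idx}: {int(mask.sum())} routed tokens, REAP = {reap:.4f}")
```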
---
2. Core code structure
The REAPObserver class registers forward hooks and collects statistics during the forward pass (a short usage sketch follows the snippet):

```python
class REAPObserver:
    def __init__(self, model, model_attrs):
        self.state = {}   # layer_idx -> {reap_sum, expert_count}
        self.hooks = []
        self._register_hooks()
```
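How the observer is driven, as a hedged sketch: `model`, `tokenizer`, `model_attrs`, and `calib_texts` are assumed to already exist (the real calibration loop appears in the full pruning script further below):

```python
import torch

# Sketch only: model, tokenizer, model_attrs, and calib_texts are assumed to be defined.
observer = REAPObserver(model, model_attrs)

with torch.no_grad():
    for text in calib_texts:
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        model(**inputs)           # hooks accumulate reap_sum / expert_count per layer

reap_scores = observer.get_reap_scores()  # {layer_idx: per-expert mean REAP scores}
observer.close()                          # detach hooks when finished
```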
---
3. The hook function (core logic)
An abridged version of the hook body; in the full script, `router_logits`, `hidden_dim`, `expert`, `top_k`, `num_experts`, and `layer_idx` come from the enclosing closure (a generic demo of the hook mechanism follows the snippet):

```python
@torch.no_grad()
def hook_fn(module, args, output):
    # 1. Get the input hidden states
    hidden_states = args[0]                            # (batch, seq, hidden_dim)
    flat_hidden = hidden_states.view(-1, hidden_dim)   # (total_tokens, hidden_dim)

    # 2. Compute router weights (softmax)
    routing_weights = F.softmax(router_logits.float(), dim=-1)

    # 3. Find the top-k selected experts
    _, selected_experts = torch.topk(router_logits, top_k, dim=-1)

    # 4. Compute the REAP score per expert
    for expert_idx in range(num_experts):
        # Tokens routed to this expert
        active_mask = (selected_experts == expert_idx).any(dim=-1)
        if not active_mask.any():
            continue

        # Router weights of those tokens
        active_router_weights = routing_weights[active_mask, expert_idx]

        # Expert output (the key step!)
        active_hidden = flat_hidden[active_mask]
        expert_output = expert(active_hidden)          # actual expert forward

        # L2 norm of the output
        activation_norms = torch.linalg.norm(expert_output.float(), dim=-1)

        # REAP = norm × weight
        reap_scores = activation_norms * active_router_weights

        # Accumulate
        self.state[layer_idx]["reap_sum"][expert_idx] += reap_scores.sum()
        self.state[layer_idx]["expert_count"][expert_idx] += active_mask.sum()
```
---
4. Computing the final REAP scores

```python
def get_reap_scores(self):
    reap_scores = {}
    for layer_idx, layer_state in self.state.items():
        counts = layer_state["expert_count"].float()
        counts = torch.clamp(counts, min=1)   # avoid division by zero
        # Mean REAP score per expert
        reap_scores[layer_idx] = layer_state["reap_sum"] / counts
    return reap_scores
```
---
5. Pruning (remove the lowest-scoring experts)
Excerpted from the pruning loop in the full script; a self-contained toy version follows the snippet:

```python
# Select the experts with the lowest REAP scores
_, experts_to_prune = torch.topk(scores, n_to_prune, largest=False)

# Indices of the experts to keep
retained_indices = [i for i in range(num_experts) if i not in experts_to_prune.tolist()]

# Prune the expert modules
new_experts = nn.ModuleList([experts[i] for i in retained_indices])
setattr(moe_block, "experts", new_experts)

# Prune the router weights as well
new_weight = router.weight.data[retained_indices, :]
new_router = nn.Linear(in_features, len(retained_indices), ...)
new_router.weight.data = new_weight
```
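A self-contained toy version of the slicing step, under assumed shapes (4 linear "experts", a bias-free router, made-up scores); it only demonstrates that the retained expert modules and router rows line up:

```python
import torch
import torch.nn as nn

num_experts, hidden = 4, 8
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
router = nn.Linear(hidden, num_experts, bias=False)
scores = torch.tensor([0.9, 0.1, 0.7, 0.5])          # pretend per-expert REAP scores

n_to_prune = 1
_, to_prune = torch.topk(scores, n_to_prune, largest=False)    # expert 1
retained = [i for i in range(num_experts) if i not in to_prune.tolist()]

new_experts = nn.ModuleList([experts[i] for i in retained])
new_router = nn.Linear(hidden, len(retained), bias=False)
new_router.weight.data = router.weight.data[retained, :]       # keep retained rows only

print(len(new_experts), new_router.weight.shape)               # 3 torch.Size([3, 8])
```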
---
6. Why REAP?

| Method          | Criterion        | Weakness                                          |
|-----------------|------------------|---------------------------------------------------|
| Frequency       | How often called | Often-called experts can still contribute little  |
| Activation Norm | Output magnitude | The router may still assign them a low weight     |
| REAP            | Norm × Weight    | Reflects the actual contribution                  |

REAP jointly accounts for "how often, how strongly, and with how much router weight" each expert contributes (a toy comparison of the three criteria follows).
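A toy comparison with synthetic numbers only; the point is just that call frequency alone or output norm alone can rank experts differently than norm × weight does:

```python
import torch

# Synthetic per-token statistics for 3 experts (illustrative only):
#   expert 0: called very often, tiny outputs, low router weight
#   expert 1: huge outputs, but the router barely trusts it
#   expert 2: moderate on both axes
calls   = [100, 10, 40]
norms   = [torch.full((c,), v) for c, v in zip(calls, [0.2, 5.0, 2.0])]
weights = [torch.full((c,), v) for c, v in zip(calls, [0.05, 0.04, 0.30])]

freq = calls                                                     # frequency criterion
norm = [n.mean().item() for n in norms]                          # activation-norm criterion
reap = [(n * w).mean().item() for n, w in zip(norms, weights)]   # REAP = norm × weight

print("frequency:", freq)   # would keep expert 0
print("norm     :", norm)   # would keep expert 1
print("REAP     :", reap)   # keeps expert 2, which contributes most to the output
```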
---
7. Evaluation script (original vs. REAP-pruned model)
```python
import torch
import time
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np

# ==========================================
# Configuration
# ==========================================
ORIGINAL_MODEL = "kakaocorp/kanana-1.5-15.7b-a3b-base"
PRUNED_MODEL = "./kanana-reap-10"  # 10% REAP pruning (official algorithm)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Test prompts for speed benchmark (Korean + English mix)
TEST_PROMPTS = [
    "인공지능의 미래에 대해 설명해주세요.",
    "한국의 역사에서 가장 중요한 사건은 무엇인가요?",
    "The theory of relativity states that",
    "Machine learning is a subset of artificial intelligence that",
    "서울에서 부산까지 가는 가장 빠른 방법은",
]


def get_gpu_memory():
    """Get current GPU memory usage in GB"""
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1e9
    return 0


def clear_memory():
    """Clear GPU memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


def load_model(model_path):
    """Load model with memory tracking"""
    clear_memory()
    print(f"\n📦 Loading: {model_path}")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    load_time = time.time() - start_time
    memory_used = get_gpu_memory()
    print(f"  ✓ Loaded in {load_time:.1f}s | GPU Memory: {memory_used:.2f} GB")
    return model, tokenizer, memory_used, load_time


def benchmark_speed(model, tokenizer, prompts, max_new_tokens=100, num_runs=3):
    """Benchmark generation speed"""
    print(f"\n⚡ Speed Benchmark (max_new_tokens={max_new_tokens}, runs={num_runs})")
    model.eval()

    total_tokens = 0
    total_time = 0
    with torch.no_grad():
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Warmup
            _ = model.generate(**inputs, max_new_tokens=10, do_sample=False)

            # Actual benchmark
            for _ in range(num_runs):
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                start = time.time()
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - start

                generated_tokens = outputs.shape[1] - inputs['input_ids'].shape[1]
                total_tokens += generated_tokens
                total_time += elapsed

    tokens_per_sec = total_tokens / total_time
    print(f"  ✓ {tokens_per_sec:.2f} tokens/sec ({total_tokens} tokens in {total_time:.2f}s)")
    return tokens_per_sec


def calculate_perplexity(model, tokenizer, texts, max_length=512):
    """Calculate perplexity on given texts"""
    print(f"\n📊 Calculating Perplexity...")
    model.eval()

    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model(**inputs, labels=inputs["input_ids"])

            # Loss is averaged over tokens
            num_tokens = inputs["input_ids"].shape[1]
            total_loss += outputs.loss.item() * num_tokens
            total_tokens += num_tokens

    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    print(f"  ✓ Perplexity: {perplexity:.2f}")
    return perplexity


def count_parameters(model):
    """Count total and MoE parameters"""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def main():
    print("=" * 60)
    print("🔬 REAP Pruning Evaluation")
    print("=" * 60)

    # Load evaluation data
    print("\n📂 Loading evaluation data...")
    try:
        # Try Korean dataset first
        eval_ds = load_dataset("maywell/ko_wikidata_QA", split="train", streaming=True)
        eval_texts = []
        for i, item in enumerate(eval_ds):
            if i >= 50:  # Use 50 samples for perplexity
                break
            text = f"질문: {item.get('instruction', '')}\n답변: {item.get('output', '')}"
            if len(text) > 50:
                eval_texts.append(text)
        print(f"  ✓ Loaded {len(eval_texts)} Korean QA samples")
    except Exception as e:
        print(f"  ⚠ Failed to load Korean data: {e}")
        eval_texts = TEST_PROMPTS * 10

    results = {}

    # ==========================================
    # Evaluate Original Model
    # ==========================================
    print("\n" + "=" * 60)
    print("📍 ORIGINAL MODEL")
    print("=" * 60)
    model, tokenizer, orig_memory, orig_load_time = load_model(ORIGINAL_MODEL)
    orig_params, _ = count_parameters(model)
    print(f"  Parameters: {orig_params / 1e9:.2f}B")

    orig_speed = benchmark_speed(model, tokenizer, TEST_PROMPTS)
    orig_ppl = calculate_perplexity(model, tokenizer, eval_texts[:30])

    results['original'] = {
        'memory_gb': orig_memory,
        'load_time': orig_load_time,
        'params_b': orig_params / 1e9,
        'tokens_per_sec': orig_speed,
        'perplexity': orig_ppl
    }

    # Clear memory before loading next model
    del model, tokenizer
    clear_memory()

    # ==========================================
    # Evaluate Pruned Model
    # ==========================================
    print("\n" + "=" * 60)
    print("📍 PRUNED MODEL (10% experts removed - official REAP)")
    print("=" * 60)
    model, tokenizer, pruned_memory, pruned_load_time = load_model(PRUNED_MODEL)
    pruned_params, _ = count_parameters(model)
    print(f"  Parameters: {pruned_params / 1e9:.2f}B")

    pruned_speed = benchmark_speed(model, tokenizer, TEST_PROMPTS)
    pruned_ppl = calculate_perplexity(model, tokenizer, eval_texts[:30])

    results['pruned'] = {
        'memory_gb': pruned_memory,
        'load_time': pruned_load_time,
        'params_b': pruned_params / 1e9,
        'tokens_per_sec': pruned_speed,
        'perplexity': pruned_ppl
    }

    # ==========================================
    # Summary
    # ==========================================
    print("\n" + "=" * 60)
    print("📈 COMPARISON SUMMARY")
    print("=" * 60)

    o, p = results['original'], results['pruned']
    param_reduction = (1 - p['params_b'] / o['params_b']) * 100
    memory_reduction = (1 - p['memory_gb'] / o['memory_gb']) * 100
    speed_improvement = (p['tokens_per_sec'] / o['tokens_per_sec'] - 1) * 100
    ppl_change = (p['perplexity'] / o['perplexity'] - 1) * 100

    print(f"""
┌─────────────────────┬─────────────┬─────────────┬─────────────┐
│ Metric              │ Original    │ Pruned      │ Change      │
├─────────────────────┼─────────────┼─────────────┼─────────────┤
│ Parameters (B)      │ {o['params_b']:>10.2f}  │ {p['params_b']:>10.2f}  │ {param_reduction:>+9.1f}%  │
│ GPU Memory (GB)     │ {o['memory_gb']:>10.2f}  │ {p['memory_gb']:>10.2f}  │ {memory_reduction:>+9.1f}%  │
│ Speed (tok/s)       │ {o['tokens_per_sec']:>10.2f}  │ {p['tokens_per_sec']:>10.2f}  │ {speed_improvement:>+9.1f}%  │
│ Perplexity          │ {o['perplexity']:>10.2f}  │ {p['perplexity']:>10.2f}  │ {ppl_change:>+9.1f}%  │
│ Load Time (s)       │ {o['load_time']:>10.1f}  │ {p['load_time']:>10.1f}  │             │
└─────────────────────┴─────────────┴─────────────┴─────────────┘

💡 Interpretation:
  - Parameter reduction: {param_reduction:.1f}% smaller model
  - Memory savings: {memory_reduction:.1f}% less GPU RAM needed
  - Speed: {"faster" if speed_improvement > 0 else "slower"} by {abs(speed_improvement):.1f}%
  - Quality: Perplexity {"increased" if ppl_change > 0 else "decreased"} by {abs(ppl_change):.1f}% (lower is better)
""")

    # Quality assessment
    if ppl_change < 5:
        print("✅ Quality preserved well! (<5% perplexity increase)")
    elif ppl_change < 10:
        print("⚠️ Slight quality degradation (5-10% perplexity increase)")
    else:
        print("❌ Significant quality loss (>10% perplexity increase)")


if __name__ == "__main__":
    main()
```
---
8. Pruning script (REAP calibration and expert removal)
| """ | |
| REAP (Router-weighted Expert Activation Pruning) - Official Algorithm | |
| Based on: https://github.com/CerebrasResearch/reap | |
| REAP Score = mean(activation_norm * router_weight) for each expert | |
| - activation_norm: L2 norm of expert OUTPUT (not input) | |
| - router_weight: softmax(router_logits) | |
| - Only considers tokens actually routed to each expert (top-k selection) | |
| Experts with LOWEST REAP scores are pruned (they contribute least to output) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| import random | |
| import gc | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ========================================== | |
| # Configuration | |
| # ========================================== | |
| @dataclass | |
| class REAPConfig: | |
| model_id: str = "kakaocorp/kanana-1.5-15.7b-a3b-base" | |
| save_path: str = "./kanana-reap-10" | |
| compression_ratio: float = 0.10 # 10% of experts to prune | |
| num_samples: int = 256 # Number of calibration samples | |
| max_seq_len: int = 1024 | |
| seed: int = 42 | |
| config = REAPConfig() | |
| random.seed(config.seed) | |
| torch.manual_seed(config.seed) | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"REAP Pruning Configuration:") | |
| print(f" Model: {config.model_id}") | |
| print(f" Compression Ratio: {config.compression_ratio*100}%") | |
| print(f" Calibration Samples: {config.num_samples}") | |
| # ========================================== | |
| # Model Attributes for Different Architectures | |
| # ========================================== | |
| MODEL_ATTRS = { | |
| "MixtralForCausalLM": { | |
| "moe_block": "block_sparse_moe", | |
| "experts": "experts", | |
| "router": "gate", | |
| "num_experts": "num_local_experts", | |
| "top_k": "num_experts_per_tok", | |
| }, | |
| "Qwen2MoeForCausalLM": { | |
| "moe_block": "mlp", | |
| "experts": "experts", | |
| "router": "gate", | |
| "num_experts": "num_experts", | |
| "top_k": "num_experts_per_tok", | |
| }, | |
| } | |
| # ========================================== | |
| # Load Model and Tokenizer | |
| # ========================================== | |
| print("\nLoading model and tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| config.model_id, | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| model.eval() | |
| # Detect model architecture | |
| model_class_name = model.__class__.__name__ | |
| print(f"Model class: {model_class_name}") | |
| # Get model attributes (fallback to Mixtral-like structure) | |
| if model_class_name in MODEL_ATTRS: | |
| model_attrs = MODEL_ATTRS[model_class_name] | |
| else: | |
| # Default to Mixtral structure for Kanana | |
| model_attrs = MODEL_ATTRS["MixtralForCausalLM"] | |
| logger.warning(f"Using MixtralForCausalLM attributes for {model_class_name}") | |
| # ========================================== | |
| # Load Calibration Dataset | |
| # ========================================== | |
| print("\nLoading calibration dataset (evol-codealpaca-v1)...") | |
| dataset = load_dataset("theblackcat102/evol-codealpaca-v1", split="train", streaming=True) | |
| calib_texts = [] | |
| iterator = iter(dataset) | |
| pbar = tqdm(total=config.num_samples, desc="Loading samples") | |
| while len(calib_texts) < config.num_samples: | |
| try: | |
| item = next(iterator) | |
| # Format as instruction-response pair | |
| text = f"### Instruction:\n{item['instruction']}\n\n### Response:\n{item['output']}" | |
| if len(text) > 100: | |
| calib_texts.append(text) | |
| pbar.update(1) | |
| except StopIteration: | |
| break | |
| pbar.close() | |
| print(f"Loaded {len(calib_texts)} calibration samples") | |
| # ========================================== | |
| # REAP Observer - Collect Expert Statistics | |
| # ========================================== | |
| class REAPObserver: | |
| """ | |
| Collects REAP statistics for each MoE layer. | |
| REAP = mean(||expert_output|| * softmax(router_logits)) | |
| for tokens actually routed to each expert. | |
| """ | |
| def __init__(self, model, model_attrs): | |
| self.model = model | |
| self.model_attrs = model_attrs | |
| self.state = {} # layer_idx -> {reap_sum, expert_count, ...} | |
| self.hooks = [] | |
| self._register_hooks() | |
| def _get_moe_block(self, layer): | |
| """Get MoE block from layer.""" | |
| return getattr(layer, self.model_attrs["moe_block"]) | |
| def _register_hooks(self): | |
| """Register forward hooks on MoE blocks.""" | |
| for layer_idx, layer in enumerate(self.model.model.layers): | |
| moe_block = self._get_moe_block(layer) | |
| if moe_block is None: | |
| continue | |
| # Check if this is actually an MoE layer | |
| if not hasattr(moe_block, self.model_attrs["experts"]): | |
| continue | |
| experts = getattr(moe_block, self.model_attrs["experts"]) | |
| router = getattr(moe_block, self.model_attrs["router"]) | |
| num_experts = len(experts) | |
| top_k = getattr(model.config, self.model_attrs["top_k"], 8) | |
| # Initialize state for this layer | |
| self.state[layer_idx] = { | |
| "reap_sum": torch.zeros(num_experts, dtype=torch.float64), | |
| "expert_count": torch.zeros(num_experts, dtype=torch.long), | |
| "num_experts": num_experts, | |
| "top_k": top_k, | |
| } | |
| # Create and register hook | |
| hook_fn = self._create_hook(layer_idx, moe_block, experts, router, num_experts, top_k) | |
| handle = moe_block.register_forward_hook(hook_fn) | |
| self.hooks.append(handle) | |
| logger.info(f"Registered hooks on {len(self.state)} MoE layers") | |
| def _create_hook(self, layer_idx, moe_block, experts, router, num_experts, top_k): | |
| """Create a hook function for a specific layer.""" | |
| @torch.no_grad() | |
| def hook_fn(module, args, output): | |
| # Get input hidden states | |
| hidden_states = args[0] # (batch, seq, hidden_dim) | |
| # Handle different output formats | |
| if isinstance(output, tuple): | |
| # Mixtral returns (output, router_logits) | |
| if len(output) >= 2: | |
| router_logits = output[1] | |
| else: | |
| router_logits = output[0] | |
| else: | |
| router_logits = output | |
| # Flatten batch and sequence dimensions | |
| batch_size, seq_len, hidden_dim = hidden_states.shape | |
| flat_hidden = hidden_states.view(-1, hidden_dim) # (total_tokens, hidden_dim) | |
| total_tokens = flat_hidden.shape[0] | |
| # Get router logits if not already computed | |
| if router_logits.shape[0] != total_tokens: | |
| router_logits = router(flat_hidden) # (total_tokens, num_experts) | |
| # Compute routing weights (softmax) | |
| routing_weights = F.softmax(router_logits.float(), dim=-1) # (total_tokens, num_experts) | |
| # Get top-k selected experts | |
| _, selected_experts = torch.topk(router_logits, top_k, dim=-1) # (total_tokens, top_k) | |
| # Compute expert activations and REAP scores | |
| device = flat_hidden.device | |
| for expert_idx in range(num_experts): | |
| # Find tokens routed to this expert | |
| active_mask = (selected_experts == expert_idx).any(dim=-1) # (total_tokens,) | |
| if not active_mask.any(): | |
| continue | |
| # Get router weights for active tokens | |
| active_router_weights = routing_weights[active_mask, expert_idx] # (n_active,) | |
| # Compute expert output for active tokens | |
| active_hidden = flat_hidden[active_mask] # (n_active, hidden_dim) | |
| expert = experts[expert_idx] | |
| expert_output = expert(active_hidden) # (n_active, hidden_dim) | |
| # Compute activation norms (L2) | |
| activation_norms = torch.linalg.norm(expert_output.float(), dim=-1) # (n_active,) | |
| # REAP score = activation_norm * router_weight | |
| reap_scores = activation_norms * active_router_weights # (n_active,) | |
| # Accumulate | |
| self.state[layer_idx]["reap_sum"][expert_idx] += reap_scores.sum().cpu() | |
| self.state[layer_idx]["expert_count"][expert_idx] += active_mask.sum().cpu() | |
| return hook_fn | |
| def get_reap_scores(self): | |
| """Compute final REAP scores (mean) for each layer.""" | |
| reap_scores = {} | |
| for layer_idx, layer_state in self.state.items(): | |
| counts = layer_state["expert_count"].float() | |
| counts = torch.clamp(counts, min=1) # Avoid division by zero | |
| reap_scores[layer_idx] = (layer_state["reap_sum"] / counts).float() | |
| return reap_scores | |
| def close(self): | |
| """Remove all hooks.""" | |
| for handle in self.hooks: | |
| handle.remove() | |
| self.hooks = [] | |
| # ========================================== | |
| # Run Calibration | |
| # ========================================== | |
| print("\nRunning REAP calibration...") | |
| observer = REAPObserver(model, model_attrs) | |
| with torch.no_grad(): | |
| for text in tqdm(calib_texts, desc="Calibration"): | |
| try: | |
| inputs = tokenizer( | |
| text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=config.max_seq_len | |
| ).to(model.device) | |
| model(**inputs) | |
| except RuntimeError as e: | |
| if "out of memory" in str(e): | |
| torch.cuda.empty_cache() | |
| continue | |
| raise | |
| # Get REAP scores | |
| reap_scores = observer.get_reap_scores() | |
| observer.close() | |
| torch.cuda.empty_cache() | |
| print("\nREAP scores computed for each layer:") | |
| for layer_idx, scores in reap_scores.items(): | |
| print(f" Layer {layer_idx}: min={scores.min():.4f}, max={scores.max():.4f}, mean={scores.mean():.4f}") | |
| # ========================================== | |
| # Pruning | |
| # ========================================== | |
| print(f"\nPruning {config.compression_ratio*100}% of experts per layer...") | |
| total_pruned = 0 | |
| for layer_idx, scores in tqdm(reap_scores.items(), desc="Pruning layers"): | |
| layer = model.model.layers[layer_idx] | |
| moe_block = getattr(layer, model_attrs["moe_block"]) | |
| experts = getattr(moe_block, model_attrs["experts"]) | |
| router = getattr(moe_block, model_attrs["router"]) | |
| num_experts = len(experts) | |
| n_to_prune = int(num_experts * config.compression_ratio) | |
| n_to_keep = num_experts - n_to_prune | |
| if n_to_keep < 1: | |
| n_to_keep = 1 | |
| n_to_prune = num_experts - 1 | |
| # Select experts with LOWEST REAP scores to prune | |
| _, experts_to_prune = torch.topk(scores, n_to_prune, largest=False) | |
| retained_indices = [i for i in range(num_experts) if i not in experts_to_prune.tolist()] | |
| retained_indices = sorted(retained_indices) | |
| print(f" Layer {layer_idx}: {num_experts} -> {n_to_keep} experts (pruning indices: {experts_to_prune.tolist()})") | |
| # Prune experts | |
| new_experts = nn.ModuleList([experts[i] for i in retained_indices]) | |
| setattr(moe_block, model_attrs["experts"], new_experts) | |
| # Prune router weights | |
| retained_indices_tensor = torch.tensor(retained_indices, device=router.weight.device) | |
| new_weight = router.weight.data[retained_indices_tensor, :] | |
| new_router = nn.Linear( | |
| router.in_features, | |
| n_to_keep, | |
| bias=router.bias is not None, | |
| device=router.weight.device, | |
| dtype=router.weight.dtype | |
| ) | |
| new_router.weight.data = new_weight | |
| if router.bias is not None: | |
| new_router.bias.data = router.bias.data[retained_indices_tensor] | |
| setattr(moe_block, model_attrs["router"], new_router) | |
| # Update metadata | |
| if hasattr(moe_block, "num_experts"): | |
| moe_block.num_experts = n_to_keep | |
| if hasattr(moe_block, "num_local_experts"): | |
| moe_block.num_local_experts = n_to_keep | |
| total_pruned += n_to_prune | |
| # Update model config | |
| if hasattr(model.config, "num_local_experts"): | |
| original_experts = model.config.num_local_experts | |
| model.config.num_local_experts = int(original_experts * (1 - config.compression_ratio)) | |
| if hasattr(model.config, "num_experts"): | |
| original_experts = model.config.num_experts | |
| model.config.num_experts = int(original_experts * (1 - config.compression_ratio)) | |
| # ========================================== | |
| # Save Pruned Model | |
| # ========================================== | |
| print(f"\nSaving pruned model to {config.save_path}...") | |
| model.save_pretrained(config.save_path) | |
| tokenizer.save_pretrained(config.save_path) | |
| print(f""" | |
| REAP Pruning Complete! | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Model: {config.model_id} | |
| Compression Ratio: {config.compression_ratio*100}% | |
| Total Experts Pruned: {total_pruned} | |
| Calibration Samples: {len(calib_texts)} | |
| Saved to: {config.save_path} | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| """) |