Created
February 22, 2026 00:16
-
-
Save alexispurslane/65bd1184f1bb8ad251da2551c1594ab7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
def calculate_inference_footprint(model_params, gpu_params, input_tokens, output_tokens):
    """
    Calculates the hardware, energy, and time footprint for an LLM workload.

    This is a simple roofline-style estimate. Prefill is treated as
    compute-bound and decode as memory-bandwidth-bound. Both run on the
    smallest power-of-two tensor-parallel group that holds the weights plus
    at least one user's full-context KV cache.

    Args:
        model_params: dict with the following keys:
            - 'total_params' and 'active_params': counts in billions of
              parameters.
            - 'layers', 'context_length', 'kv_lora_rank', 'qk_rope_head_dim'.
            - optional 'dtype_size': bytes per value, default 1.
            A 'heads' key may be present but is not used, because the MLA
            cache is shared by all heads.
        gpu_params: dict with the following keys:
            - 'ram_gb': per-GPU memory, interpreted in GiB.
            - 'bandwidth_tb_s': per-GPU memory bandwidth in TB/s.
            - 'flops_tflop': per-GPU throughput in TFLOP/s.
            - 'power_draw_watts': power of one node of GPUs.
            - optional 'gpus_per_node': GPUs in that node, default 8.
            - optional 'pue': wall-to-cluster energy factor, default 1.4.
        input_tokens: number of prompt tokens processed during prefill.
        output_tokens: number of tokens generated during decode.

    Returns:
        dict with 'tensor_parallelism', 'concurrent_users_limit',
        'prefill_tps', 'decode_tps', 'cluster_seconds_used',
        'energy_kwh_wall' and 'system_power_kw'.

    Raises:
        ValueError: if the per-GPU memory, the per-user KV cache size, or the
            active parameter count is not positive.
    """
    # Units
    BILLION = 1_000_000_000  # parameter counts are given in billions
    GiB = 1024**3            # GPU memory is interpreted in binary gigabytes
    TB = 1_000_000_000_000
    TFLOP = 1_000_000_000_000

    # 1. Model specifics
    dtype_size = model_params.get('dtype_size', 1)
    total_params = model_params['total_params'] * BILLION
    active_params = model_params['active_params'] * BILLION
    weight_size = total_params * dtype_size
    active_size = active_params * dtype_size
    # Multi-head latent attention caches two things per layer and token:
    # one compressed KV latent, plus one decoupled RoPE key shared by every
    # head. The head count therefore does not multiply the cache size.
    kv_cache_per_user = (
        model_params['layers']
        * model_params['context_length']
        * (model_params['kv_lora_rank'] + model_params['qk_rope_head_dim'])
        * dtype_size
    )

    gpu_ram_total = gpu_params['ram_gb'] * GiB
    if gpu_ram_total <= 0:
        raise ValueError("gpu_params['ram_gb'] must be positive")
    if kv_cache_per_user <= 0:
        raise ValueError("KV cache size per user must be positive")
    if active_params <= 0:
        raise ValueError("model_params['active_params'] must be positive")

    # 2. Hardware topology (TP calculation)
    # Pick the smallest power-of-two GPU count (never below 1) whose memory
    # fits the weights plus one user's KV cache. With less room, no user
    # could be served and decode throughput would be zero.
    tp = 1
    while tp * gpu_ram_total < weight_size + kv_cache_per_user:
        tp *= 2

    # 3. Concurrency (U_max): memory left after the weights, split into
    # full-context KV caches.
    total_usable_ram = tp * gpu_ram_total
    u_max = math.floor((total_usable_ram - weight_size) / kv_cache_per_user)

    # 4. Throughput (TPS)
    total_cluster_flops = gpu_params['flops_tflop'] * TFLOP * tp
    # Each token costs ~2 FLOPs (multiply + add) per active parameter.
    # This depends on the parameter count, not on bytes per parameter.
    flops_per_token = 2 * active_params
    prefill_tps = total_cluster_flops / flops_per_token
    cluster_bandwidth = gpu_params['bandwidth_tb_s'] * TB * tp
    # Each decode step streams the active weights once and emits one token
    # for every concurrent user.
    passes_per_second = cluster_bandwidth / active_size
    decode_tps = passes_per_second * u_max

    # 5. Workload execution
    cluster_seconds = (input_tokens / prefill_tps) + (output_tokens / decode_tps)

    # 6. Energy & power: node power scaled by the fraction of a node in use,
    # then by the wall-to-cluster factor.
    gpus_per_node = gpu_params.get('gpus_per_node', 8)
    pue = gpu_params.get('pue', 1.4)
    system_watts = gpu_params["power_draw_watts"] * (tp / gpus_per_node)
    energy_joules = system_watts * cluster_seconds
    energy_kwh_cluster = energy_joules / 3_600_000
    energy_kwh_wall = energy_kwh_cluster * pue

    return {
        "tensor_parallelism": tp,
        "concurrent_users_limit": u_max,
        "prefill_tps": round(prefill_tps, 2),
        "decode_tps": round(decode_tps, 2),
        "cluster_seconds_used": round(cluster_seconds, 2),
        "energy_kwh_wall": round(energy_kwh_wall, 4),
        "system_power_kw": system_watts / 1000,
    }
def get_household_comparisons(kwh):
    """
    Returns how many units of common household activities equal the energy used.

    Args:
        kwh: energy in kilowatt-hours.

    Returns:
        dict mapping an activity label to ``kwh`` divided by the energy of
        one unit of that activity, rounded to 2 decimal places. Insertion
        order follows the table below.
    """
    # kWh consumed by ONE unit of each activity (one charge, one hour, one
    # minute, one load, one slice). The label's unit and the value must agree.
    comparisons = {
        "Smartphone Charges": 0.005,          # ~5 Wh per charge
        "Hours of LED TV (50\")": 0.06,       # 60 W for one hour
        "Hours of Laptop Use": 0.05,          # 50 W for one hour
        "Microwave Minutes (1000W)": 1/60,    # 1000 W for one minute, ~0.0167 kWh
        "Hours of LED Lighting (10W)": 0.01,  # 10 W for one hour
        "Loads of Laundry (Cold)": 0.5,
        # ~20 Wh toasts 2 slices, i.e. ~10 Wh per slice. The label counts
        # slices, so the value must be per slice, not per 2-slice batch.
        "Toasted Bread Slices": 0.01,
    }
    return {item: round(kwh / per_unit, 2) for item, per_unit in comparisons.items()}
# Configuration
# NOTE(review): these values appear to describe a 1T-total / 32B-active MLA
# mixture-of-experts model (Kimi K2-like) in 1-byte weights; confirm source.
model = {
    "total_params": 1000, "active_params": 32, "layers": 61,
    "context_length": 256000, "kv_lora_rank": 512,
    "qk_rope_head_dim": 64, "dtype_size": 1, "heads": 64
}
# NOTE(review): presumably a B200-class GPU, with power_draw_watts covering an
# 8-GPU node (the footprint function scales it by tp / 8); confirm source.
gpu = {
    "ram_gb": 192, "bandwidth_tb_s": 8.0,
    "flops_tflop": 9000, "power_draw_watts": 14300
}

# Run Analysis: 22.9M prompt tokens and 116.3k generated tokens.
stats = calculate_inference_footprint(model, gpu, 22_900_000, 116_300)
equivalents = get_household_comparisons(stats['energy_kwh_wall'])

# Output (plain strings where there is nothing to interpolate)
print("--- Hardware Profile ---")
print(f"Cluster Size: {stats['tensor_parallelism']} GPUs")
print(f"System Power: {stats['system_power_kw']} kW")
print("--- Performance ---")
print(f"Total Time: {stats['cluster_seconds_used']}s")
print(f"Energy Used: {stats['energy_kwh_wall']} kWh")
print("\n--- Household Equivalents ---")
for activity, count in equivalents.items():
    print(f" • {activity.ljust(25)}: {count}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment