Benchmark PyTorch matrix multiplication with a locked GPU clock for stable performance.
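(For reference, a similar effect can usually be achieved from the shell on recent drivers with `sudo nvidia-smi -lgc <min_mhz>,<max_mhz>` and undone with `sudo nvidia-smi -rgc`; the script below does the locking programmatically via nvidia-ml-py instead.)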
| """ | |
| Benchmark matrix multiplication with locked GPU clock for stable performance. | |
| Requires: pip install nvidia-ml-py torch numpy | |
| """ | |
| import pynvml | |
| import torch | |
| import random | |
| import os | |
| import numpy as np | |
| from torch.profiler import profile, ProfilerActivity, schedule | |
| torch.manual_seed(0) | |
| random.seed(0) | |
| np.random.seed(0) | |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
| torch.use_deterministic_algorithms(True) | |
class GPUClockLocker:
    """Lock GPU clocks to the TDP base frequency for stable performance."""

    def __init__(self, device_index=0, enabled=True):
        """
        Args:
            device_index: GPU device index (default 0)
            enabled: Whether to actually lock clocks (useful for disabling the lock)
        """
        self.device_index = device_index
        self.handle = None
        self.enabled = enabled

    def __enter__(self):
        if not self.enabled:
            return self
        pynvml.nvmlInit()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_index)
        # Lock both the minimum and maximum GPU clock to the TDP base frequency.
        pynvml.nvmlDeviceSetGpuLockedClocks(
            self.handle,
            pynvml.NVML_CLOCK_LIMIT_ID_TDP,
            pynvml.NVML_CLOCK_LIMIT_ID_TDP,
        )
        print("GPU clocks locked to TDP frequency.")
        print("Run `nvidia-smi dmon -s pucvmet` to monitor GPU clock and power consumption.")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return False
        # Restore the default clock behaviour before shutting NVML down,
        # so the GPU is not left pinned to the TDP clock after the benchmark.
        pynvml.nvmlDeviceResetGpuLockedClocks(self.handle)
        pynvml.nvmlShutdown()
        return False
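# Optional sanity-check helper (a sketch, not used by the benchmark below; assumes
# NVML is already initialised, e.g. while inside the GPUClockLocker context):
# read back the current SM clock in MHz to confirm the lock took effect.
def read_sm_clock_mhz(device_index=0):
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    return pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)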
def benchmark(fn, workspace_generator, num_warmup_runs=1000, num_active_runs=50, num_workspaces=50):
    """
    Benchmark a function with GPU clock locking for stable measurements.

    Args:
        fn: Function to benchmark (takes workspace arguments)
        workspace_generator: Function that generates workspace data
        num_warmup_runs: Number of warmup iterations
        num_active_runs: Number of measured iterations
        num_workspaces: Number of pre-generated workspaces to cycle through
    """
    with GPUClockLocker():
        N = num_warmup_runs + num_active_runs
        workspaces = [workspace_generator() for _ in range(num_workspaces)]
        torch.cuda.synchronize()
        with profile(activities=[ProfilerActivity.CUDA]) as prof:
            for i in range(N):
                workspace = workspaces[i % num_workspaces]
                fn(*workspace)
                prof.step()
        torch.cuda.synchronize()

        # Collect per-kernel device durations (reported in microseconds)
        kernel_durations = {}
        for e in prof.events():
            if e.device_type.name == "CUDA":
                kernel_name = e.name
                if kernel_name not in kernel_durations:
                    kernel_durations[kernel_name] = []
                kernel_durations[kernel_name].append(e.device_time)

        # Print statistics
        print("\nBenchmark Results:")
        print("-" * 80)
        for kernel_name, durations in kernel_durations.items():
            # Only keep the active (post-warmup) runs
            durations = np.array(durations[-num_active_runs:])
            avg = np.mean(durations)
            mad = np.mean(np.abs(durations - avg))  # mean absolute deviation
            print(f"{kernel_name}")
            print(f"  Duration: {avg / 1e3:.5f} ms ± {mad:.2f} μs")
        print("-" * 80)
def generate_workspace(m=4096, n=4096, k=4096, dtype=torch.float32, device="cuda"):
    """Generate workspace for matrix multiplication: C = A @ B"""
    a = torch.zeros(m, k, dtype=dtype, device=device)
    b = torch.zeros(k, n, dtype=dtype, device=device)
    c = torch.zeros(m, n, dtype=dtype, device=device)
    return a, b, c
if __name__ == "__main__":
    print("Matrix Multiplication Benchmark with GPU Clock Locking")
    print("=" * 80)

    device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(device)
    print(f"Using device: {device_name}")
    print()

    # Run benchmark
    benchmark(
        fn=lambda a, b, c: torch.matmul(a, b, out=c),
        workspace_generator=generate_workspace,
        num_warmup_runs=1000,
        num_active_runs=100,
        num_workspaces=50,
    )
# RTX 5090
#   Duration: 2.44506 ms ± 2.00 μs
#   Duration: 2.44465 ms ± 2.17 μs
#   Duration: 2.44447 ms ± 2.40 μs
#   Duration: 2.44407 ms ± 2.20 μs
#   Duration: 2.44395 ms ± 1.91 μs
#
# A40
#   Duration: 7.13419 ms ± 1.81 μs
#   Duration: 7.13409 ms ± 1.70 μs
#   Duration: 7.13443 ms ± 1.74 μs
#   Duration: 7.13429 ms ± 1.81 μs
#   Duration: 7.13432 ms ± 1.47 μs
#
# A100
#   Duration: 8.79951 ms ± 1.54 μs
#   Duration: 8.79955 ms ± 1.33 μs
#   Duration: 8.79935 ms ± 1.19 μs
#   Duration: 8.79895 ms ± 1.10 μs
#   Duration: 8.79923 ms ± 1.10 μs
#
# L4
#   Duration: 15.14415 ms ± 15.64 μs
#   Duration: 15.16400 ms ± 15.87 μs
#   Duration: 15.16389 ms ± 16.61 μs
#   Duration: 15.16424 ms ± 17.61 μs
#   Duration: 15.17881 ms ± 17.44 μs
#
# T4
#   Duration: 47.49176 ms ± 4.11 μs
#   Duration: 47.49174 ms ± 3.78 μs
#   Duration: 47.49215 ms ± 3.67 μs
#   Duration: 47.49260 ms ± 3.67 μs
#   Duration: 47.49169 ms ± 3.84 μs
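The reported durations can be converted into effective throughput: a 4096×4096×4096 fp32 matmul performs 2·4096³ ≈ 137.4 GFLOP, so for example the ~2.44 ms figure above corresponds to roughly 56 TFLOP/s. A minimal sketch of that conversion (the example duration is taken from the RTX 5090 results above):

# Convert a measured matmul duration into effective throughput.
m = n = k = 4096
flops = 2 * m * n * k          # each multiply-add counted as 2 FLOP
duration_s = 2.44506e-3        # example: first RTX 5090 duration above
print(f"{flops / duration_s / 1e12:.1f} TFLOP/s")  # ~56.2 TFLOP/s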