Created
August 7, 2024 21:31
-
-
Save functionstackx/afea0b3423dd2eecaba0b61c633c96de to your computer and use it in GitHub Desktop.
H100 Memory Bandwidth (Mem BW)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Message sizes to benchmark, in bytes: every power of two from 2**0 (1 B)
# through 2**34 (16 GB).
sizes = [1 << exponent for exponent in range(35)]
# How many timed copies are averaged for each message size.
iterations = 1000
| # Function to measure memory bandwidth | |
def measure_bandwidth():
    """Measure GPU-to-GPU tensor copy bandwidth for each configured size.

    For every entry of the module-level ``sizes`` list, a source and a
    destination buffer of exactly that many bytes are allocated on the
    ``'cuda'`` device. One untimed warm-up copy is issued, then
    ``iterations`` copies are timed one by one with CUDA events and the
    elapsed times are averaged.

    Returns:
        list[tuple[int, float]]: ``(size_in_bytes, bandwidth_in_GB_per_s)``
        pairs, one per entry of ``sizes`` and in the same order. The
        bandwidth counts each payload byte once (bytes copied divided by
        the average copy time) and uses 1 GB = 1e9 bytes.

    Raises:
        ZeroDivisionError: if ``iterations`` is 0.
    """
    results = []
    for size in sizes:
        # One-byte elements, so the buffers hold exactly `size` bytes. With
        # float32 elements the 1 B and 2 B cases would still move a 4-byte
        # element while being accounted as 1 or 2 bytes.
        src = torch.ones(size, dtype=torch.uint8, device='cuda')
        dst = torch.zeros(size, dtype=torch.uint8, device='cuda')

        # Warm-up copy, excluded from the timings.
        dst.copy_(src)
        torch.cuda.synchronize()

        # The same pair of events is re-recorded on every iteration.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_elapsed_ms = 0.0
        for _ in range(iterations):
            start_event.record()
            dst.copy_(src)
            end_event.record()
            # Both events must have completed before elapsed_time is read.
            torch.cuda.synchronize()
            total_elapsed_ms += start_event.elapsed_time(end_event)

        avg_elapsed_time_ms = total_elapsed_ms / iterations

        # bytes / (ms * 1e6) == (bytes / 1e9) / (ms / 1e3) == GB/s
        bandwidth = size / (avg_elapsed_time_ms * 1e6)
        results.append((size, bandwidth))
        print(f"Size: {size / 1e6} MB, Bandwidth: {bandwidth:.2f} GB/s")

        # Drop these buffers before the next (larger) pair is allocated, so
        # the old and new buffers are never referenced at the same time.
        del src, dst
    return results
| # Run the measurement | |
if __name__ == "__main__":
    # Script entry point: run the benchmark, save the raw numbers as CSV,
    # then draw, save and display a bandwidth-vs-size plot.
    size_column = 'Size (bytes)'
    bandwidth_column = 'Bandwidth (GB/s)'

    def _tick_label(num_bytes):
        """Format a byte count as a whole number of B, KB, MB or GB (1024-based)."""
        if num_bytes < 1024:
            return f"{int(num_bytes)}B"
        if num_bytes < 1024**2:
            return f"{int(num_bytes / 1024)}KB"
        if num_bytes < 1024**3:
            return f"{int(num_bytes / 1024**2)}MB"
        return f"{int(num_bytes / 1024**3)}GB"

    measurements = measure_bandwidth()

    # Tabulate the (size, bandwidth) pairs and persist them.
    df = pd.DataFrame(measurements, columns=[size_column, bandwidth_column])
    df.to_csv('bandwidth_results.csv', index=False)

    # Bandwidth against message size, with a base-2 logarithmic x-axis.
    plt.figure(figsize=(10, 6))
    plt.plot(df[size_column], df[bandwidth_column], marker='o')
    plt.xscale('log', base=2)

    # One x tick per benchmarked size, labelled in human-readable units.
    x_labels = [_tick_label(num_bytes) for num_bytes in sizes]
    plt.xticks(ticks=sizes, labels=x_labels, rotation=45)

    # One y tick every 100 GB/s, extending past the highest measurement.
    y_upper = max(df[bandwidth_column]) + 100
    plt.yticks(ticks=np.arange(0, y_upper, 100))

    plt.xlabel('Message Size')
    plt.ylabel('Bandwidth (GB/s)')
    plt.title('Memory Bandwidth vs Message Size')
    plt.grid(True)
    plt.tight_layout()

    # Write the image to disk first, then open the interactive window.
    plt.savefig('bandwidth_plot.png')
    plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Size (bytes) | Bandwidth (GB/s) | |
|---|---|---|
| 1 | 9.067511596005305e-05 | |
| 2 | 0.00018463647315995806 | |
| 4 | 0.0003685391403448731 | |
| 8 | 0.00073772644566514 | |
| 16 | 0.0014971404628611466 | |
| 32 | 0.002989822644883229 | |
| 64 | 0.005978054563879316 | |
| 128 | 0.011931536829088691 | |
| 256 | 0.023927666669290044 | |
| 512 | 0.04798133529160767 | |
| 1024 | 0.09543806059784696 | |
| 2048 | 0.19311193877093613 | |
| 4096 | 0.38321632248626125 | |
| 8192 | 0.7680675922287223 | |
| 16384 | 1.5398727818841034 | |
| 32768 | 3.083965793934213 | |
| 65536 | 6.133168825096853 | |
| 131072 | 12.380382436548839 | |
| 262144 | 24.559815085850147 | |
| 524288 | 49.16916009169341 | |
| 1048576 | 98.27402641108246 | |
| 2097152 | 196.0535605940026 | |
| 4194304 | 390.17768582305064 | |
| 8388608 | 756.7470831803522 | |
| 16777216 | 1207.1217915092502 | |
| 33554432 | 1017.3119708659103 | |
| 67108864 | 1154.3026186403176 | |
| 134217728 | 1236.9465796967484 | |
| 268435456 | 1283.5936519674017 | |
| 536870912 | 1295.281562135945 | |
| 1073741824 | 1311.8103389938299 | |
| 2147483648 | 1321.7200408294004 | |
| 4294967296 | 1327.8750549498875 | |
| 8589934592 | 1330.2828550259078 | |
| 17179869184 | 1330.6639859873212 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment